diff --git a/components/brave_wallet/browser/BUILD.gn b/components/brave_wallet/browser/BUILD.gn index c322e93d8661..6ca8fca00b6e 100644 --- a/components/brave_wallet/browser/BUILD.gn +++ b/components/brave_wallet/browser/BUILD.gn @@ -312,10 +312,14 @@ static_library("browser") { if (enable_orchard) { sources += [ + "zcash/orchard_shard_tree_delegate_impl.cc", + "zcash/orchard_shard_tree_delegate_impl.h", "zcash/zcash_create_shield_transaction_task.cc", "zcash/zcash_create_shield_transaction_task.h", "zcash/zcash_orchard_storage.cc", "zcash/zcash_orchard_storage.h", + "zcash/zcash_orchard_sync_state.cc", + "zcash/zcash_orchard_sync_state.h", "zcash/zcash_shield_sync_service.cc", "zcash/zcash_shield_sync_service.h", ] diff --git a/components/brave_wallet/browser/internal/BUILD.gn b/components/brave_wallet/browser/internal/BUILD.gn index c957e9fbd3b8..6609d09bf387 100644 --- a/components/brave_wallet/browser/internal/BUILD.gn +++ b/components/brave_wallet/browser/internal/BUILD.gn @@ -39,12 +39,22 @@ source_set("hd_key") { } if (enable_orchard) { + source_set("test_support") { + sources = [ + "orchard_test_utils.cc", + "orchard_test_utils.h", + ] + deps = [ "//brave/components/brave_wallet/browser/zcash/rust" ] + } + source_set("orchard_bundle") { sources = [ "orchard_block_scanner.cc", "orchard_block_scanner.h", "orchard_bundle_manager.cc", "orchard_bundle_manager.h", + "orchard_shard_tree_manager.cc", + "orchard_shard_tree_manager.h", ] deps = [ "//brave/components/brave_wallet/browser/zcash/rust" ] } diff --git a/components/brave_wallet/browser/internal/orchard_block_scanner.cc b/components/brave_wallet/browser/internal/orchard_block_scanner.cc index 186d740b0c34..df71d63897cd 100644 --- a/components/brave_wallet/browser/internal/orchard_block_scanner.cc +++ b/components/brave_wallet/browser/internal/orchard_block_scanner.cc @@ -5,19 +5,28 @@ #include "brave/components/brave_wallet/browser/internal/orchard_block_scanner.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h" + namespace brave_wallet { OrchardBlockScanner::Result::Result() = default; -OrchardBlockScanner::Result::Result(std::vector discovered_notes, - std::vector spent_notes) +OrchardBlockScanner::Result::Result( + std::vector discovered_notes, + std::vector spent_notes, + std::unique_ptr scanned_blocks) : discovered_notes(std::move(discovered_notes)), - spent_notes(std::move(spent_notes)) {} - -OrchardBlockScanner::Result::Result(const Result&) = default; + found_spends(std::move(spent_notes)), + scanned_blocks(std::move(scanned_blocks)) {} +OrchardBlockScanner::Result::Result(OrchardBlockScanner::Result&&) = default; OrchardBlockScanner::Result& OrchardBlockScanner::Result::operator=( - const Result&) = default; + OrchardBlockScanner::Result&&) = default; + +// OrchardBlockScanner::Result::Result(const Result&) = default; + +// OrchardBlockScanner::Result& OrchardBlockScanner::Result::operator=( +// const Result&) = default; OrchardBlockScanner::Result::~Result() = default; @@ -29,24 +38,24 @@ OrchardBlockScanner::~OrchardBlockScanner() = default; base::expected OrchardBlockScanner::ScanBlocks( - std::vector known_notes, + OrchardTreeState tree_state, std::vector blocks) { - std::vector found_nullifiers; - std::vector found_notes; + std::unique_ptr result = + decoder_->ScanBlocks(tree_state, blocks); + if (!result) { + DVLOG(1) << "Failed to parse block range."; + return base::unexpected(ErrorCode::kInputError); + } + + if (!result->GetDiscoveredNotes()) { + DVLOG(1) << "Failed to resolve discovered notes."; + return base::unexpected(ErrorCode::kInputError); + } + + std::vector found_spends; + std::vector found_notes = result->GetDiscoveredNotes().value(); for (const auto& block : blocks) { - // Scan block using the decoder initialized with the provided fvk - // to find new spendable notes. - auto scan_result = decoder_->ScanBlock(block); - if (!scan_result) { - return base::unexpected(ErrorCode::kDecoderError); - } - found_notes.insert(found_notes.end(), scan_result->begin(), - scan_result->end()); - // Place found notes to the known notes list so we can also check for - // nullifiers - known_notes.insert(known_notes.end(), scan_result->begin(), - scan_result->end()); for (const auto& tx : block->vtx) { // We only scan orchard actions here for (const auto& orchard_action : tx->orchard_actions) { @@ -54,25 +63,31 @@ OrchardBlockScanner::ScanBlocks( return base::unexpected(ErrorCode::kInputError); } - std::array action_nullifier; - base::ranges::copy(orchard_action->nullifier, action_nullifier.begin()); - + OrchardNoteSpend spend; // Nullifier is a public information about some note being spent. - // Here we are trying to find a known spendable notes which nullifier + // -- Here we are trying to find a known spendable notes which nullifier // matches nullifier from the processed transaction. - if (std::find_if(known_notes.begin(), known_notes.end(), - [&action_nullifier](const auto& v) { - return v.nullifier == action_nullifier; - }) != known_notes.end()) { - OrchardNullifier nullifier; - nullifier.block_id = block->height; - nullifier.nullifier = action_nullifier; - found_nullifiers.push_back(std::move(nullifier)); - } + base::ranges::copy(orchard_action->nullifier, spend.nullifier.begin()); + spend.block_id = block->height; + found_spends.push_back(std::move(spend)); } } } - return Result({std::move(found_notes), std::move(found_nullifiers)}); + + return Result( + {std::move(found_notes), std::move(found_spends), std::move(result)}); +} + +// static +OrchardBlockScanner::Result OrchardBlockScanner::CreateResultForTesting( + const OrchardTreeState& tree_state, + const std::vector& commitments) { + auto builder = orchard::OrchardDecodedBlocksBundle::CreateTestingBuilder(); + for (const auto& commitment : commitments) { + builder->AddCommitment(commitment); + } + builder->SetPriorTreeState(tree_state); + return Result{{}, {}, builder->Complete()}; } } // namespace brave_wallet diff --git a/components/brave_wallet/browser/internal/orchard_block_scanner.h b/components/brave_wallet/browser/internal/orchard_block_scanner.h index e5380b02edd4..f4e27adb55d1 100644 --- a/components/brave_wallet/browser/internal/orchard_block_scanner.h +++ b/components/brave_wallet/browser/internal/orchard_block_scanner.h @@ -12,7 +12,9 @@ #include #include "base/types/expected.h" +#include "brave/components/brave_wallet/browser/internal/orchard_block_scanner.h" #include "brave/components/brave_wallet/browser/zcash/rust/orchard_block_decoder.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h" #include "brave/components/brave_wallet/common/zcash_utils.h" namespace brave_wallet { @@ -26,15 +28,19 @@ class OrchardBlockScanner { struct Result { Result(); Result(std::vector discovered_notes, - std::vector spent_notes); - Result(const Result&); - Result& operator=(const Result&); + std::vector spent_notes, + std::unique_ptr scanned_blocks); + Result(const Result&) = delete; + Result& operator=(const Result&) = delete; + Result(Result&&); + Result& operator=(Result&&); ~Result(); // New notes have been discovered std::vector discovered_notes; // Nullifiers for the previously discovered notes - std::vector spent_notes; + std::vector found_spends; + std::unique_ptr scanned_blocks; }; explicit OrchardBlockScanner(const OrchardFullViewKey& full_view_key); @@ -43,9 +49,13 @@ class OrchardBlockScanner { // Scans blocks to find incoming notes related to fvk // Also checks whether existing notes were spent. virtual base::expected ScanBlocks( - std::vector known_notes, + OrchardTreeState tree_state, std::vector blocks); + static Result CreateResultForTesting( + const OrchardTreeState& tree_state, + const std::vector& commitments); + private: std::unique_ptr decoder_; }; diff --git a/components/brave_wallet/browser/internal/orchard_block_scanner_unittest.cc b/components/brave_wallet/browser/internal/orchard_block_scanner_unittest.cc index 9fae8b28942c..c3df069ef84f 100644 --- a/components/brave_wallet/browser/internal/orchard_block_scanner_unittest.cc +++ b/components/brave_wallet/browser/internal/orchard_block_scanner_unittest.cc @@ -189,7 +189,7 @@ TEST(OrchardBlockScannerTest, DiscoverNewNotes) { EXPECT_EQ(result.value().discovered_notes[3].block_id, 11u); EXPECT_EQ(result.value().discovered_notes[3].amount, 2549979667u); - EXPECT_EQ(result.value().spent_notes.size(), 0u); + EXPECT_EQ(result.value().found_spends.size(), 5u); } TEST(OrchardBlockScannerTest, WrongInput) { @@ -469,11 +469,13 @@ TEST(OrchardBlockScanner, FoundKnownNullifiers_SameBatch) { EXPECT_EQ(result.value().discovered_notes[0].block_id, 10u); EXPECT_EQ(result.value().discovered_notes[0].amount, 3625561528u); - EXPECT_EQ(result.value().spent_notes.size(), 1u); - EXPECT_EQ(result.value().spent_notes[0].block_id, 11u); + EXPECT_EQ(result.value().found_spends.size(), 2u); + EXPECT_EQ(result.value().found_spends[0].block_id, 10u); + EXPECT_EQ(result.value().found_spends[1].block_id, 11u); + EXPECT_EQ( - std::vector(result.value().spent_notes[0].nullifier.begin(), - result.value().spent_notes[0].nullifier.end()), + std::vector(result.value().found_spends[1].nullifier.begin(), + result.value().found_spends[1].nullifier.end()), PrefixedHexStringToBytes( "0x6588cc7fabfab2b2a4baa89d4dfafaa50cc89d22f96d10fb7689461b921ad40d") .value()); @@ -499,9 +501,9 @@ TEST(OrchardBlockScanner, FoundKnownNullifiers) { PrefixedHexStringToBytes( "0x1b32edbbe4d18f28876de262518ad31122701f8c0a52e98047a337876e7eea19") .value(); - OrchardNullifier nf; - base::ranges::copy(nullifier_bytes, nf.nullifier.begin()); - nf.block_id = 10; + OrchardNoteSpend spend; + base::ranges::copy(nullifier_bytes, spend.nullifier.begin()); + spend.block_id = 10; action->nullifier = nullifier_bytes; action->ciphertext = std::vector(kOrchardCipherTextSize, 0); @@ -525,11 +527,13 @@ TEST(OrchardBlockScanner, FoundKnownNullifiers) { notes.push_back(note); blocks.push_back(std::move(block)); - auto result = scanner.ScanBlocks(std::move(notes), std::move(blocks)); + OrchardTreeState tree_state; + + auto result = scanner.ScanBlocks(tree_state, std::move(blocks)); EXPECT_TRUE(result.has_value()); - EXPECT_EQ(result.value().spent_notes.size(), 1u); - EXPECT_EQ(result.value().spent_notes[0], nf); + EXPECT_EQ(result.value().found_spends.size(), 1u); + EXPECT_EQ(result.value().found_spends[0].nullifier, spend.nullifier); EXPECT_EQ(result.value().discovered_notes.size(), 0u); } diff --git a/components/brave_wallet/browser/internal/orchard_shard_tree_manager.cc b/components/brave_wallet/browser/internal/orchard_shard_tree_manager.cc new file mode 100644 index 000000000000..0ca1201b68bd --- /dev/null +++ b/components/brave_wallet/browser/internal/orchard_shard_tree_manager.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/brave_wallet/browser/internal/orchard_shard_tree_manager.h" + +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet { + +// static +std::unique_ptr OrchardShardTreeManager::Create( + std::unique_ptr delegate) { + auto shard_tree = orchard::OrchardShardTree::Create(std::move(delegate)); + if (!shard_tree) { + return nullptr; + } + return std::make_unique(std::move(shard_tree)); +} + +// static +std::unique_ptr +OrchardShardTreeManager::CreateForTesting( + std::unique_ptr delegate) { + auto shard_tree = + orchard::OrchardShardTree::CreateForTesting(std::move(delegate)); + if (!shard_tree) { + return nullptr; + } + return std::make_unique(std::move(shard_tree)); +} + +OrchardShardTreeManager::OrchardShardTreeManager( + std::unique_ptr<::brave_wallet::orchard::OrchardShardTree> shard_tree) { + orchard_shard_tree_ = std::move(shard_tree); +} + +OrchardShardTreeManager::~OrchardShardTreeManager() {} + +bool OrchardShardTreeManager::InsertCommitments( + OrchardBlockScanner::Result result) { + return orchard_shard_tree_->ApplyScanResults( + std::move(result.scanned_blocks)); +} + +base::expected, std::string> +OrchardShardTreeManager::CalculateWitness(std::vector notes, + uint32_t checkpoint_position) { + for (auto& input : notes) { + auto witness = orchard_shard_tree_->CalculateWitness( + input.note.orchard_commitment_tree_position, checkpoint_position); + if (!witness.has_value()) { + return base::unexpected(witness.error()); + } + input.witness = witness.value(); + } + return notes; +} + +bool OrchardShardTreeManager::Truncate(uint32_t checkpoint) { + return orchard_shard_tree_->TruncateToCheckpoint(checkpoint); +} + +} // namespace brave_wallet diff --git a/components/brave_wallet/browser/internal/orchard_shard_tree_manager.h b/components/brave_wallet/browser/internal/orchard_shard_tree_manager.h new file mode 100644 index 000000000000..a8ffed320096 --- /dev/null +++ b/components/brave_wallet/browser/internal/orchard_shard_tree_manager.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_INTERNAL_ORCHARD_SHARD_TREE_MANAGER_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_INTERNAL_ORCHARD_SHARD_TREE_MANAGER_H_ + +#include +#include +#include +#include + +#include "base/types/expected.h" +#include "brave/components/brave_wallet/browser/internal/orchard_block_scanner.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_shard_tree.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet { + +class OrchardShardTreeManager { + public: + OrchardShardTreeManager( + std::unique_ptr<::brave_wallet::orchard::OrchardShardTree> shard_tree); + ~OrchardShardTreeManager(); + bool InsertCommitments(OrchardBlockScanner::Result commitments); + base::expected, std::string> CalculateWitness( + std::vector notes, + uint32_t checkpoint_position); + bool Truncate(uint32_t checkpoint); + base::expected VerifyCheckpoint(); + + static std::unique_ptr Create( + std::unique_ptr delegate); + + // Creates shard tree size of 8 for testing + static std::unique_ptr CreateForTesting( + std::unique_ptr delegate); + + private: + std::unique_ptr<::brave_wallet::orchard::OrchardShardTree> + orchard_shard_tree_; +}; + +} // namespace brave_wallet + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_INTERNAL_ORCHARD_SHARD_TREE_MANAGER_H_ diff --git a/components/brave_wallet/browser/internal/orchard_test_utils.cc b/components/brave_wallet/browser/internal/orchard_test_utils.cc new file mode 100644 index 000000000000..293ae015ce39 --- /dev/null +++ b/components/brave_wallet/browser/internal/orchard_test_utils.cc @@ -0,0 +1,24 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/brave_wallet/browser/internal/orchard_test_utils.h" + +#include "base/memory/ptr_util.h" + +namespace brave_wallet { + +OrchardTestUtils::OrchardTestUtils() { + orchard_test_utils_impl_ = orchard::OrchardTestUtils::Create(); +} + +OrchardTestUtils::~OrchardTestUtils() {} + +OrchardCommitmentValue OrchardTestUtils::CreateMockCommitmentValue( + uint32_t position, + uint32_t rseed) { + return orchard_test_utils_impl_->CreateMockCommitmentValue(position, rseed); +} + +} // namespace brave_wallet diff --git a/components/brave_wallet/browser/internal/orchard_test_utils.h b/components/brave_wallet/browser/internal/orchard_test_utils.h new file mode 100644 index 000000000000..06867038897b --- /dev/null +++ b/components/brave_wallet/browser/internal/orchard_test_utils.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_INTERNAL_ORCHARD_TEST_UTILS_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_INTERNAL_ORCHARD_TEST_UTILS_H_ + +#include + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_test_utils.h" + +namespace brave_wallet { + +class OrchardTestUtils { + public: + OrchardTestUtils(); + ~OrchardTestUtils(); + + OrchardCommitmentValue CreateMockCommitmentValue(uint32_t position, + uint32_t rseed); + + private: + std::unique_ptr orchard_test_utils_impl_; +}; + +} // namespace brave_wallet + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_INTERNAL_ORCHARD_TEST_UTILS_H_ diff --git a/components/brave_wallet/browser/test/BUILD.gn b/components/brave_wallet/browser/test/BUILD.gn index 1cb94f2a57e1..de1d9c25e2da 100644 --- a/components/brave_wallet/browser/test/BUILD.gn +++ b/components/brave_wallet/browser/test/BUILD.gn @@ -107,7 +107,8 @@ source_set("brave_wallet_unit_tests") { "//brave/components/brave_wallet/browser/wallet_data_files_installer_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_grpc_utils_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_keyring_unittest.cc", - "//brave/components/brave_wallet/browser/zcash/zcash_serializer_unittest.cc", + + # "//brave/components/brave_wallet/browser/zcash/zcash_serializer_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_transaction_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_transaction_utils_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_wallet_service_unittest.cc", @@ -161,11 +162,14 @@ source_set("brave_wallet_unit_tests") { "//brave/components/brave_wallet/browser/internal/hd_key_zip32_unittest.cc", "//brave/components/brave_wallet/browser/internal/orchard_block_scanner_unittest.cc", "//brave/components/brave_wallet/browser/internal/orchard_bundle_manager_unittest.cc", + "//brave/components/brave_wallet/browser/zcash/orchard_shard_tree_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_orchard_storage_unittest.cc", "//brave/components/brave_wallet/browser/zcash/zcash_shield_sync_service_unittest.cc", ] - deps += - [ "//brave/components/brave_wallet/browser/internal:orchard_bundle" ] + deps += [ + "//brave/components/brave_wallet/browser/internal:orchard_bundle", + "//brave/components/brave_wallet/browser/internal:test_support", + ] } if (!is_ios) { diff --git a/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.cc b/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.cc new file mode 100644 index 000000000000..2eb81e203ec0 --- /dev/null +++ b/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.cc @@ -0,0 +1,196 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.h" + +namespace brave_wallet { + +namespace { + +OrchardShardTreeDelegate::Error From(ZCashOrchardStorage::Error) { + return OrchardShardTreeDelegate::Error::kStorageError; +} + +} // namespace + +OrchardShardTreeDelegateImpl::OrchardShardTreeDelegateImpl( + mojom::AccountIdPtr account_id, + scoped_refptr storage) + : account_id_(std::move(account_id)), storage_(storage) {} + +OrchardShardTreeDelegateImpl::~OrchardShardTreeDelegateImpl() {} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetCap() const { + auto result = storage_->GetCap(account_id_.Clone()); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(*result); +} + +base::expected +OrchardShardTreeDelegateImpl::PutCap(OrchardCap cap) { + auto result = storage_->PutCap(account_id_.Clone(), cap); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return *result; +} + +base::expected +OrchardShardTreeDelegateImpl::Truncate(uint32_t block_height) { + auto result = storage_->TruncateShards(account_id_.Clone(), block_height); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return *result; +} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetLatestShardIndex() const { + auto result = storage_->GetLatestShardIndex(account_id_.Clone()); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return *result; +} + +base::expected +OrchardShardTreeDelegateImpl::PutShard(OrchardShard shard) { + auto result = storage_->PutShard(account_id_.Clone(), shard); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return *result; +} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetShard(OrchardShardAddress address) const { + auto result = storage_->GetShard(account_id_.Clone(), address); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(*result); +} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::LastShard(uint8_t shard_height) const { + auto result = storage_->LastShard(account_id_.Clone(), shard_height); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(*result); +} + +base::expected +OrchardShardTreeDelegateImpl::CheckpointCount() const { + auto result = storage_->CheckpointCount(account_id_.Clone()); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return *result; +} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::MinCheckpointId() const { + auto result = storage_->MinCheckpointId(account_id_.Clone()); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::MaxCheckpointId() const { + auto result = storage_->MaxCheckpointId(account_id_.Clone()); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected, OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetCheckpointAtDepth(uint32_t depth) const { + auto result = storage_->GetCheckpointAtDepth(account_id_.Clone(), depth); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected, + OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetCheckpoint(uint32_t checkpoint_id) const { + auto result = storage_->GetCheckpoint(account_id_.Clone(), checkpoint_id); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected, + OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetCheckpoints(size_t limit) const { + auto result = storage_->GetCheckpoints(account_id_.Clone(), limit); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected +OrchardShardTreeDelegateImpl::AddCheckpoint(uint32_t id, + OrchardCheckpoint checkpoint) { + auto result = storage_->AddCheckpoint(account_id_.Clone(), id, checkpoint); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected +OrchardShardTreeDelegateImpl::TruncateCheckpoints(uint32_t checkpoint_id) { + auto result = + storage_->TruncateCheckpoints(account_id_.Clone(), checkpoint_id); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected +OrchardShardTreeDelegateImpl::RemoveCheckpoint(uint32_t checkpoint_id) { + auto result = storage_->RemoveCheckpoint(account_id_.Clone(), checkpoint_id); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected +OrchardShardTreeDelegateImpl::RemoveCheckpointAt(uint32_t depth) { + return false; +} + +base::expected, + OrchardShardTreeDelegate::Error> +OrchardShardTreeDelegateImpl::GetShardRoots(uint8_t shard_level) const { + auto result = storage_->GetShardRoots(account_id_.Clone(), shard_level); + if (!result.has_value()) { + return base::unexpected(From(result.error())); + } + return std::move(result.value()); +} + +base::expected +OrchardShardTreeDelegateImpl::UpdateCheckpoint(uint32_t id, + OrchardCheckpoint checkpoint) { + // RemoveCheckpoint(id); + // AddCheckpoint(id, checkpoint); + return false; +} + +} // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.h b/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.h new file mode 100644 index 000000000000..f30d38fd3124 --- /dev/null +++ b/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ORCHARD_SHARD_TREE_DELEGATE_IMPL_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ORCHARD_SHARD_TREE_DELEGATE_IMPL_H_ + +#include +#include + +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_storage.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet { + +class OrchardShardTreeDelegateImpl : public OrchardShardTreeDelegate { + public: + OrchardShardTreeDelegateImpl(mojom::AccountIdPtr account_id, + scoped_refptr storage); + ~OrchardShardTreeDelegateImpl() override; + + base::expected, Error> GetCap() const override; + base::expected PutCap(OrchardCap cap) override; + base::expected Truncate(uint32_t block_height) override; + base::expected, Error> GetLatestShardIndex() + const override; + base::expected PutShard(OrchardShard shard) override; + base::expected, Error> GetShard( + OrchardShardAddress address) const override; + base::expected, Error> LastShard( + uint8_t shard_height) const override; + base::expected CheckpointCount() const override; + base::expected, Error> MinCheckpointId() + const override; + base::expected, Error> MaxCheckpointId() + const override; + base::expected, Error> GetCheckpointAtDepth( + uint32_t depth) const override; + base::expected, Error> GetCheckpoint( + uint32_t checkpoint_id) const override; + base::expected, Error> GetCheckpoints( + size_t limit) const override; + base::expected AddCheckpoint( + uint32_t id, + OrchardCheckpoint checkpoint) override; + base::expected TruncateCheckpoints( + uint32_t checkpoint_id) override; + base::expected RemoveCheckpoint(uint32_t checkpoint_id) override; + base::expected RemoveCheckpointAt(uint32_t depth) override; + base::expected, Error> GetShardRoots( + uint8_t shard_level) const override; + base::expected UpdateCheckpoint( + uint32_t id, + OrchardCheckpoint checkpoint) override; + + private: + mojom::AccountIdPtr account_id_; + scoped_refptr storage_; +}; + +} // namespace brave_wallet + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ORCHARD_SHARD_TREE_DELEGATE_IMPL_H_ diff --git a/components/brave_wallet/browser/zcash/orchard_shard_tree_unittest.cc b/components/brave_wallet/browser/zcash/orchard_shard_tree_unittest.cc new file mode 100644 index 000000000000..8699fb8b6aa8 --- /dev/null +++ b/components/brave_wallet/browser/zcash/orchard_shard_tree_unittest.cc @@ -0,0 +1,514 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "base/files/scoped_temp_dir.h" +#include "brave/components/brave_wallet/browser/internal/orchard_shard_tree_manager.h" +#include "brave/components/brave_wallet/browser/internal/orchard_test_utils.h" +#include "brave/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.h" +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h" +#include "brave/components/brave_wallet/common/common_utils.h" +#include "brave/components/brave_wallet/common/hex_utils.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" +#include "content/public/test/browser_task_environment.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace brave_wallet { + +namespace { + +constexpr uint32_t kDefaultCommitmentSeed = 1; + +OrchardNoteWitness CreateWitness(const std::vector& path, + uint32_t position) { + OrchardNoteWitness result; + for (const auto& path_elem : path) { + OrchardMerkleHash as_bytes; + EXPECT_TRUE(base::HexStringToSpan(path_elem, as_bytes)); + result.merkle_path.push_back(as_bytes); + } + result.position = position; + return result; +} + +OrchardCommitment CreateCommitment(OrchardCommitmentValue value, + bool marked, + std::optional checkpoint_id) { + return OrchardCommitment{value, marked, checkpoint_id}; +} + +} // namespace + +class OrchardShardTreeTest : public testing::Test { + public: + OrchardShardTreeTest() + : task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME) {} + void SetUp() override; + + OrchardShardTreeManager* tree_manager() { return shard_tree_manager_.get(); } + + OrchardTestUtils* test_utils() { return orchard_test_utils_.get(); } + + ZCashOrchardStorage* storage() { return storage_.get(); } + + mojom::AccountIdPtr account_id() { return account_id_.Clone(); } + + private: + base::test::TaskEnvironment task_environment_; + base::ScopedTempDir temp_dir_; + mojom::AccountIdPtr account_id_; + + scoped_refptr storage_; + std::unique_ptr shard_tree_manager_; + std::unique_ptr orchard_test_utils_; +}; + +void OrchardShardTreeTest::SetUp() { + account_id_ = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); + base::FilePath db_path( + temp_dir_.GetPath().Append(FILE_PATH_LITERAL("orchard.db"))); + storage_ = base::WrapRefCounted(new ZCashOrchardStorage(db_path)); + shard_tree_manager_ = OrchardShardTreeManager::CreateForTesting( + std::make_unique(account_id_.Clone(), + storage_)); + orchard_test_utils_ = std::make_unique(); +} + +TEST_F(OrchardShardTreeTest, CheckpointsPruned) { + std::vector commitments; + + for (int i = 0; i < 40; i++) { + std::optional checkpoint; + if (i % 2 == 0) { + checkpoint = i * 2; + } + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, checkpoint)); + } + OrchardTreeState orchard_tree_state; + auto result = OrchardBlockScanner::CreateResultForTesting(orchard_tree_state, + commitments); + + tree_manager()->InsertCommitments(std::move(result)); + + EXPECT_EQ(10u, storage()->CheckpointCount(account_id()).value()); + EXPECT_EQ(40u, storage()->MinCheckpointId(account_id()).value().value()); + EXPECT_EQ(76u, storage()->MaxCheckpointId(account_id()).value().value()); +} + +TEST_F(OrchardShardTreeTest, InsertWithFrontier) { + OrchardTreeState prior_tree_state; + prior_tree_state.block_height = 0; + prior_tree_state.tree_size = 48; + prior_tree_state.frontier = std::vector( + {1, 72, 173, 200, 225, 47, 142, 44, 148, 137, 119, 18, 99, 211, + 92, 65, 67, 173, 197, 93, 7, 85, 70, 105, 140, 223, 184, 193, + 172, 9, 194, 88, 62, 1, 130, 31, 76, 59, 69, 55, 151, 124, + 101, 120, 230, 247, 201, 82, 48, 160, 150, 48, 23, 84, 250, 117, + 120, 175, 108, 220, 96, 214, 42, 255, 209, 44, 7, 1, 13, 59, + 69, 136, 45, 180, 148, 18, 146, 125, 241, 196, 224, 205, 11, 196, + 195, 90, 164, 186, 175, 22, 90, 105, 82, 149, 34, 131, 232, 132, + 223, 15, 1, 211, 200, 193, 46, 24, 11, 42, 42, 182, 124, 29, + 48, 234, 215, 28, 103, 218, 239, 234, 109, 10, 231, 74, 70, 197, + 113, 131, 89, 199, 71, 102, 33, 1, 153, 86, 62, 213, 2, 98, + 191, 65, 218, 123, 73, 155, 243, 225, 45, 10, 241, 132, 49, 33, + 101, 183, 59, 35, 56, 78, 228, 47, 166, 10, 237, 50, 0, 1, + 94, 228, 186, 123, 0, 136, 39, 192, 226, 129, 40, 253, 0, 83, + 248, 138, 7, 26, 120, 212, 191, 135, 44, 0, 171, 42, 69, 6, + 133, 205, 115, 4, 0, 0}); + + std::vector commitments; + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(48, kDefaultCommitmentSeed), + false, std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(49, kDefaultCommitmentSeed), + false, std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(50, kDefaultCommitmentSeed), true, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(51, kDefaultCommitmentSeed), + false, 1)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(52, kDefaultCommitmentSeed), + false, std::nullopt)); + + auto result = OrchardBlockScanner::CreateResultForTesting(prior_tree_state, + commitments); + tree_manager()->InsertCommitments(std::move(result)); + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 50; + + auto witness_result = tree_manager()->CalculateWitness({input}, 1); + EXPECT_TRUE(witness_result.has_value()); + EXPECT_EQ( + witness_result.value()[0].witness.value(), + CreateWitness( + {"9695d64b1ccd38aa5dfdc5c70aecf0e763549034318c59943a3e3e921b415c3a", + "48ddf8a84afc5949e074c162630e3f6aab3d4350bf929ba82677cee4c634e029", + "c7413f4614cd64043abbab7cc1095c9bb104231cea89e2c3e0df83769556d030", + "2111fc397753e5fd50ec74816df27d6ada7ed2a9ac3816aab2573c8fac794204", + "2d99471d096691e4a5f43efe469734aff37f4f21c707b060c952a84169f9302f", + "5ee4ba7b008827c0e28128fd0053f88a071a78d4bf872c00ab2a450685cd7304", + "27ab1320953ae1ad70c8c15a1253a0a86fbc8a0aa36a84207293f8a495ffc402", + "4e14563df191a2a65b4b37113b5230680555051b22d74a8e1f1d706f90f3133" + "b"}, + 50)); + } +} + +TEST_F(OrchardShardTreeTest, Checkpoint_WithMarked) { + std::vector commitments; + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(0, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(1, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(2, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(3, kDefaultCommitmentSeed), true, + 1)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(4, kDefaultCommitmentSeed), false, + std::nullopt)); + + OrchardTreeState tree_state; + auto result = + OrchardBlockScanner::CreateResultForTesting(tree_state, commitments); + tree_manager()->InsertCommitments(std::move(result)); + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 3; + auto witness_result = tree_manager()->CalculateWitness({input}, 1); + EXPECT_TRUE(witness_result.has_value()); + + EXPECT_EQ( + witness_result.value()[0].witness.value(), + CreateWitness( + {"3bb11bd05d2ed5e590369f274a1a247d390380aa0590160bfbf72cb186d7023f", + "d4059d13ddcbe9ec7e6fc99bdf9bfd08b0a678d26e3bf6a734e7688eca669f37", + "c7413f4614cd64043abbab7cc1095c9bb104231cea89e2c3e0df83769556d030", + "2111fc397753e5fd50ec74816df27d6ada7ed2a9ac3816aab2573c8fac794204", + "806afbfeb45c64d4f2384c51eff30764b84599ae56a7ab3d4a46d9ce3aeab431", + "873e4157f2c0f0c645e899360069fcc9d2ed9bc11bf59827af0230ed52edab18", + "27ab1320953ae1ad70c8c15a1253a0a86fbc8a0aa36a84207293f8a495ffc402", + "4e14563df191a2a65b4b37113b5230680555051b22d74a8e1f1d706f90f3133" + "b"}, + 3)); + } +} + +TEST_F(OrchardShardTreeTest, MinCheckpoint) { + std::vector commitments; + + for (int i = 0; i < 40; i++) { + std::optional checkpoint; + if (i % 2 == 0) { + checkpoint = i * 2; + } + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, checkpoint)); + } + OrchardTreeState tree_state; + auto result = + OrchardBlockScanner::CreateResultForTesting(tree_state, commitments); + tree_manager()->InsertCommitments(std::move(result)); + + EXPECT_EQ(10u, storage()->CheckpointCount(account_id()).value()); + EXPECT_EQ(40u, storage()->MinCheckpointId(account_id()).value().value()); + EXPECT_EQ(76u, storage()->MaxCheckpointId(account_id()).value().value()); +} + +TEST_F(OrchardShardTreeTest, MaxCheckpoint) { + { + std::vector commitments; + + for (int i = 0; i < 5; i++) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, std::nullopt)); + } + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(5, kDefaultCommitmentSeed), + false, 1u)); + OrchardTreeState tree_state; + auto result = + OrchardBlockScanner::CreateResultForTesting(tree_state, commitments); + tree_manager()->InsertCommitments(std::move(result)); + } + + { + std::vector commitments; + + for (int i = 6; i < 10; i++) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, std::nullopt)); + } + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(10, kDefaultCommitmentSeed), + false, 2u)); + OrchardTreeState tree_state; + tree_state.block_height = 1; + tree_state.tree_size = 6; + auto result = + OrchardBlockScanner::CreateResultForTesting(tree_state, commitments); + tree_manager()->InsertCommitments(std::move(result)); + } + + { + std::vector commitments; + + for (int i = 11; i < 15; i++) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, std::nullopt)); + } + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(15, kDefaultCommitmentSeed), + false, 3u)); + OrchardTreeState tree_state; + tree_state.block_height = 2; + tree_state.tree_size = 11; + auto result = + OrchardBlockScanner::CreateResultForTesting(tree_state, commitments); + tree_manager()->InsertCommitments(std::move(result)); + } + + EXPECT_EQ(3u, storage()->CheckpointCount(account_id()).value()); + EXPECT_EQ(1u, storage()->MinCheckpointId(account_id()).value().value()); + EXPECT_EQ(3u, storage()->MaxCheckpointId(account_id()).value().value()); +} + +TEST_F(OrchardShardTreeTest, NoWitnessOnNonMarked) { + std::vector commitments; + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(0, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(1, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(2, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(3, kDefaultCommitmentSeed), false, + 1)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(4, kDefaultCommitmentSeed), false, + std::nullopt)); + + auto result = OrchardBlockScanner::CreateResultForTesting(OrchardTreeState(), + commitments); + tree_manager()->InsertCommitments(std::move(result)); + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 2; + auto witness_result = tree_manager()->CalculateWitness({input}, 1); + EXPECT_FALSE(witness_result.has_value()); + } +} + +TEST_F(OrchardShardTreeTest, NoWitnessOnWrongCheckpoint) { + std::vector commitments; + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(0, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(1, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(2, kDefaultCommitmentSeed), true, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(3, kDefaultCommitmentSeed), false, + 1)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(4, kDefaultCommitmentSeed), false, + std::nullopt)); + + auto result = OrchardBlockScanner::CreateResultForTesting(OrchardTreeState(), + commitments); + tree_manager()->InsertCommitments(std::move(result)); + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 2; + auto witness_result = tree_manager()->CalculateWitness({input}, 2); + EXPECT_FALSE(witness_result.has_value()); + } +} + +TEST_F(OrchardShardTreeTest, TruncateTree) { + { + std::vector commitments; + + for (int i = 0; i < 10; i++) { + if (i == 2) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + true, std::nullopt)); + } else if (i == 3) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, 1)); + } else if (i == 5) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, 2)); + } else { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(i, kDefaultCommitmentSeed), + false, std::nullopt)); + } + } + + auto result = OrchardBlockScanner::CreateResultForTesting( + OrchardTreeState(), commitments); + tree_manager()->InsertCommitments(std::move(result)); + } + + tree_manager()->Truncate(2); + + { + std::vector commitments; + + for (int j = 0; j < 5; j++) { + if (j == 3) { + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(j, 5), false, 2)); + } else { + commitments.push_back( + CreateCommitment(test_utils()->CreateMockCommitmentValue(j, 5), + false, std::nullopt)); + } + } + + OrchardTreeState tree_state; + tree_state.block_height = 1; + // Truncate was on position 5, so 5 elements left in the tre + tree_state.tree_size = 5; + auto result = + OrchardBlockScanner::CreateResultForTesting(tree_state, commitments); + tree_manager()->InsertCommitments(std::move(result)); + } + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 2; + auto witness_result = tree_manager()->CalculateWitness({input}, 2); + EXPECT_TRUE(witness_result.has_value()); + } + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 2; + auto witness_result = tree_manager()->CalculateWitness({input}, 1); + EXPECT_TRUE(witness_result.has_value()); + EXPECT_EQ( + witness_result.value()[0].witness.value(), + CreateWitness( + {"f342eb6489f4e5b5a0fb0a4ece48d137dcd5e80011aab4668913f98be2af3311", + "d4059d13ddcbe9ec7e6fc99bdf9bfd08b0a678d26e3bf6a734e7688eca669f37", + "c7413f4614cd64043abbab7cc1095c9bb104231cea89e2c3e0df83769556d030", + "2111fc397753e5fd50ec74816df27d6ada7ed2a9ac3816aab2573c8fac794204", + "806afbfeb45c64d4f2384c51eff30764b84599ae56a7ab3d4a46d9ce3aeab431", + "873e4157f2c0f0c645e899360069fcc9d2ed9bc11bf59827af0230ed52edab18", + "27ab1320953ae1ad70c8c15a1253a0a86fbc8a0aa36a84207293f8a495ffc402", + "4e14563df191a2a65b4b37113b5230680555051b22d74a8e1f1d706f90f3133" + "b"}, + 2)); + } +} + +TEST_F(OrchardShardTreeTest, TruncateTreeWrongCheckpoint) { + std::vector commitments; + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(0, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(1, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(2, kDefaultCommitmentSeed), true, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(3, kDefaultCommitmentSeed), false, + 1)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(4, kDefaultCommitmentSeed), false, + std::nullopt)); + + auto result = OrchardBlockScanner::CreateResultForTesting(OrchardTreeState(), + commitments); + tree_manager()->InsertCommitments(std::move(result)); + + EXPECT_FALSE(tree_manager()->Truncate(2)); +} + +TEST_F(OrchardShardTreeTest, SimpleInsert) { + std::vector commitments; + + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(0, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(1, kDefaultCommitmentSeed), false, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(2, kDefaultCommitmentSeed), true, + std::nullopt)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(3, kDefaultCommitmentSeed), false, + 1)); + commitments.push_back(CreateCommitment( + test_utils()->CreateMockCommitmentValue(4, kDefaultCommitmentSeed), false, + std::nullopt)); + + auto result = OrchardBlockScanner::CreateResultForTesting(OrchardTreeState(), + commitments); + tree_manager()->InsertCommitments(std::move(result)); + + { + OrchardInput input; + input.note.orchard_commitment_tree_position = 2; + auto witness_result = tree_manager()->CalculateWitness({input}, 1); + EXPECT_TRUE(witness_result.has_value()); + EXPECT_EQ( + witness_result.value()[0].witness.value(), + CreateWitness( + {"f342eb6489f4e5b5a0fb0a4ece48d137dcd5e80011aab4668913f98be2af3311", + "d4059d13ddcbe9ec7e6fc99bdf9bfd08b0a678d26e3bf6a734e7688eca669f37", + "c7413f4614cd64043abbab7cc1095c9bb104231cea89e2c3e0df83769556d030", + "2111fc397753e5fd50ec74816df27d6ada7ed2a9ac3816aab2573c8fac794204", + "806afbfeb45c64d4f2384c51eff30764b84599ae56a7ab3d4a46d9ce3aeab431", + "873e4157f2c0f0c645e899360069fcc9d2ed9bc11bf59827af0230ed52edab18", + "27ab1320953ae1ad70c8c15a1253a0a86fbc8a0aa36a84207293f8a495ffc402", + "4e14563df191a2a65b4b37113b5230680555051b22d74a8e1f1d706f90f3133" + "b"}, + 2)); + } +} + +} // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/rust/BUILD.gn b/components/brave_wallet/browser/zcash/rust/BUILD.gn index 673e5d05dfd6..2c39869f1d87 100644 --- a/components/brave_wallet/browser/zcash/rust/BUILD.gn +++ b/components/brave_wallet/browser/zcash/rust/BUILD.gn @@ -24,6 +24,9 @@ source_set("orchard_headers") { "authorized_orchard_bundle.h", "extended_spending_key.h", "orchard_block_decoder.h", + "orchard_decoded_blocks_bunde.h", + "orchard_shard_tree.h", + "orchard_test_utils.h", "unauthorized_orchard_bundle.h", ] @@ -44,6 +47,16 @@ source_set("orchard_impl") { "extended_spending_key_impl.h", "orchard_block_decoder_impl.cc", "orchard_block_decoder_impl.h", + "orchard_decoded_blocks_bunde_impl.cc", + "orchard_decoded_blocks_bunde_impl.h", + "orchard_shard_tree_impl.cc", + "orchard_shard_tree_impl.h", + + # TODO(cypt4) : Extract to rust tests target + "orchard_test_utils_impl.cc", + "orchard_test_utils_impl.h", + "orchard_testing_shard_tree_impl.cc", + "orchard_testing_shard_tree_impl.h", "unauthorized_orchard_bundle_impl.cc", "unauthorized_orchard_bundle_impl.h", ] @@ -57,6 +70,17 @@ source_set("orchard_impl") { ] } +source_set("shard_store") { + visibility = [ ":*" ] + sources = [ "cxx/src/shard_store.h" ] + + public_deps = [ + "//base", + "//brave/components/brave_wallet/common", + "//build/rust:cxx_cppdeps", + ] +} + rust_static_library("rust_lib") { visibility = [ ":orchard_impl" ] @@ -69,14 +93,18 @@ rust_static_library("rust_lib") { deps = [ "librustzcash:zcash_client_backend", "librustzcash:zcash_primitives", + "librustzcash:zcash_protocol", "//brave/components/brave_wallet/rust:rust_lib", "//brave/third_party/rust/incrementalmerkletree/v0_5:lib", "//brave/third_party/rust/memuse/v0_2:lib", "//brave/third_party/rust/nonempty/v0_7:lib", "//brave/third_party/rust/orchard/v0_8:lib", + "//brave/third_party/rust/pasta_curves/v0_5:lib", "//brave/third_party/rust/rand/v0_8:lib", "//brave/third_party/rust/shardtree/v0_3:lib", "//brave/third_party/rust/zcash_note_encryption/v0_4:lib", "//third_party/rust/byteorder/v1:lib", ] + + public_deps = [ ":shard_store" ] } diff --git a/components/brave_wallet/browser/zcash/rust/DEPS b/components/brave_wallet/browser/zcash/rust/DEPS index b6c9bddb9ffa..aeaebf34a4c8 100644 --- a/components/brave_wallet/browser/zcash/rust/DEPS +++ b/components/brave_wallet/browser/zcash/rust/DEPS @@ -8,5 +8,8 @@ specific_include_rules = { "unauthorized_orchard_bundle_impl.h": [ "+third_party/rust/cxx/v1/cxx.h", ], + "orchard_decoded_blocks_bunde_impl.h": [ + "+third_party/rust/cxx/v1/cxx.h", + ] } diff --git a/components/brave_wallet/browser/zcash/rust/cxx/src/DEPS b/components/brave_wallet/browser/zcash/rust/cxx/src/DEPS new file mode 100644 index 000000000000..db15a8ff731d --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/cxx/src/DEPS @@ -0,0 +1,6 @@ +specific_include_rules = { + "shard_store.h": [ + "+third_party/rust/cxx/v1/cxx.h", + ], +} + diff --git a/components/brave_wallet/browser/zcash/rust/cxx/src/shard_store.h b/components/brave_wallet/browser/zcash/rust/cxx/src/shard_store.h new file mode 100644 index 000000000000..538d4be28cc1 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/cxx/src/shard_store.h @@ -0,0 +1,72 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_CXX_SRC_SHARD_STORE_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_CXX_SRC_SHARD_STORE_H_ + +#include "brave/components/brave_wallet/common/zcash_utils.h" +#include "third_party/rust/cxx/v1/cxx.h" + +namespace brave_wallet::orchard { + +enum class ShardStoreStatusCode : uint32_t; +struct ShardTreeShard; +struct ShardTreeAddress; +struct ShardTreeCheckpoint; +struct ShardTreeCap; +struct ShardTreeCheckpointBundle; + +using ShardStoreContext = ::brave_wallet::OrchardShardTreeDelegate; + +ShardStoreStatusCode shard_store_last_shard(const ShardStoreContext& ctx, + ShardTreeShard& into); +ShardStoreStatusCode shard_store_put_shard(ShardStoreContext& ctx, + const ShardTreeShard& tree); +ShardStoreStatusCode shard_store_get_shard(const ShardStoreContext& ctx, + const ShardTreeAddress& addr, + ShardTreeShard& tree); +ShardStoreStatusCode shard_store_get_shard_roots( + const ShardStoreContext& ctx, + ::rust::Vec& into); +ShardStoreStatusCode shard_store_truncate(ShardStoreContext& ctx, + const ShardTreeAddress& address); +ShardStoreStatusCode shard_store_get_cap(const ShardStoreContext& ctx, + ShardTreeCap& into); +ShardStoreStatusCode shard_store_put_cap(ShardStoreContext& ctx, + const ShardTreeCap& tree); +ShardStoreStatusCode shard_store_min_checkpoint_id(const ShardStoreContext& ctx, + uint32_t& into); +ShardStoreStatusCode shard_store_max_checkpoint_id(const ShardStoreContext& ctx, + uint32_t& into); +ShardStoreStatusCode shard_store_add_checkpoint( + ShardStoreContext& ctx, + uint32_t checkpoint_id, + const ShardTreeCheckpoint& checkpoint); +ShardStoreStatusCode shard_store_checkpoint_count(const ShardStoreContext& ctx, + size_t& into); +ShardStoreStatusCode shard_store_get_checkpoint_at_depth( + const ShardStoreContext& ctx, + size_t depth, + uint32_t& into_checkpoint_id, + ShardTreeCheckpoint& into_checpoint); +ShardStoreStatusCode shard_store_get_checkpoint(const ShardStoreContext& ctx, + uint32_t checkpoint_id, + ShardTreeCheckpoint& into); +ShardStoreStatusCode shard_store_update_checkpoint( + ShardStoreContext& ctx, + uint32_t checkpoint_id, + const ShardTreeCheckpoint& checkpoint); +ShardStoreStatusCode shard_store_remove_checkpoint(ShardStoreContext& ctx, + uint32_t checkpoint_id); +ShardStoreStatusCode shard_store_truncate_checkpoint(ShardStoreContext& ctx, + uint32_t checkpoint_id); +ShardStoreStatusCode shard_store_get_checkpoints( + const ShardStoreContext& ctx, + size_t limit, + ::rust::Vec& into); + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_CXX_SRC_SHARD_STORE_H_ diff --git a/components/brave_wallet/browser/zcash/rust/lib.rs b/components/brave_wallet/browser/zcash/rust/lib.rs index 4c34c3ad2f4e..bf48e8888935 100644 --- a/components/brave_wallet/browser/zcash/rust/lib.rs +++ b/components/brave_wallet/browser/zcash/rust/lib.rs @@ -3,52 +3,100 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -use std::fmt; +use std::{ + cell::RefCell, cmp::{Ord, Ordering}, + collections::BTreeSet, convert::TryFrom, error, fmt, io::Cursor, marker::PhantomData, + ops::{Add, Bound, RangeBounds, Sub}, rc::Rc, vec}; use orchard::{ builder:: { - BuildError as OrchardBuildError, - InProgress, - Unproven, - Unauthorized - }, - bundle::Bundle, - zip32::ChildIndex as OrchardChildIndex, - keys::Scope as OrchardScope, - keys::FullViewingKey as OrchardFVK, - keys::PreparedIncomingViewingKey, - zip32::Error as Zip32Error, - zip32::ExtendedSpendingKey, - tree::MerkleHashOrchard, - note_encryption:: { - OrchardDomain, - CompactAction - }, + BuildError as OrchardBuildError, InProgress, Unauthorized, Unproven + }, bundle::{commitments, Bundle}, + keys::{FullViewingKey as OrchardFVK, + PreparedIncomingViewingKey, + Scope as OrchardScope, SpendingKey}, note:: { - Nullifier, - ExtractedNoteCommitment - } + ExtractedNoteCommitment, Nullifier, RandomSeed, Rho + }, note_encryption:: { + CompactAction, OrchardDomain + }, keys::SpendAuthorizingKey, + tree::{MerkleHashOrchard, MerklePath}, + value::NoteValue, + zip32::{ + ChildIndex as OrchardChildIndex, + Error as Zip32Error, + ExtendedSpendingKey}, + Anchor }; use zcash_note_encryption::EphemeralKeyBytes; -use zcash_primitives::transaction::components::amount::Amount; - -use ffi::OrchardOutput; +use zcash_protocol::consensus::BlockHeight; +use zcash_primitives::{ + merkle_tree::{read_commitment_tree, HashSer}, + transaction::components::amount::Amount}; -use rand::rngs::OsRng; -use rand::{RngCore, Error as OtherError}; -use rand::CryptoRng; +use incrementalmerkletree::{ + frontier::{self, Frontier}, + Address, + Position, + Retention}; -use brave_wallet::{ - impl_error -}; +use rand::{rngs::OsRng, CryptoRng, Error as OtherError, RngCore}; +use brave_wallet::impl_error; +use std::sync::Arc; use zcash_note_encryption::{ batch, Domain, ShieldedOutput, COMPACT_NOTE_SIZE, }; +use shardtree::{ + error::ShardTreeError, + store::{Checkpoint, ShardStore, TreeState}, + LocatedPrunableTree, LocatedTree, PrunableTree, RetentionFlags, + ShardTree, +}; + +use zcash_client_backend::serialization::shardtree::{read_shard, write_shard}; +use cxx::UniquePtr; + +use pasta_curves::{group::ff::Field, pallas}; + +use crate::ffi::OrchardSpend; +use crate::ffi::OrchardOutput; use crate::ffi::OrchardCompactAction; +use crate::ffi::ShardTreeShard; +use crate::ffi::ShardTreeAddress; +use crate::ffi::ShardTreeCheckpoint; +use crate::ffi::ShardTreeCap; +use crate::ffi::ShardTreeCheckpointRetention; +use crate::ffi::ShardTreeState; +use crate::ffi::ShardStoreStatusCode; +use crate::ffi::ShardTreeLeaf; +use crate::ffi::ShardTreeLeafs; +use crate::ffi::ShardStoreContext; +use crate::ffi::ShardTreeCheckpointBundle; + +use crate::ffi::{ + shard_store_add_checkpoint, + shard_store_checkpoint_count, + shard_store_get_cap, + shard_store_get_checkpoint, + shard_store_get_checkpoint_at_depth, + shard_store_get_checkpoints, + shard_store_get_shard, + shard_store_get_shard_roots, + shard_store_last_shard, + shard_store_max_checkpoint_id, + shard_store_min_checkpoint_id, + shard_store_put_cap, + shard_store_put_shard, + shard_store_remove_checkpoint, + shard_store_truncate, + shard_store_truncate_checkpoint, + shard_store_update_checkpoint}; +use shardtree::error::QueryError; + // The rest of the wallet code should be updated to use this version of unwrap // and then this code can be removed #[macro_export] @@ -90,6 +138,16 @@ macro_rules! impl_result { }; } +pub(crate) const PRUNING_DEPTH: u8 = 100; +pub(crate) const SHARD_HEIGHT: u8 = 16; +pub(crate) const TREE_HEIGHT: u8 = 32; +pub(crate) const CHUNK_SIZE: usize = 1024; + +pub(crate) const TESTING_PRUNING_DEPTH: u8 = 10; +pub(crate) const TESTING_SHARD_HEIGHT: u8 = 4; +pub(crate) const TESTING_TREE_HEIGHT: u8 = 8; +pub(crate) const TESTING_CHUNK_SIZE: usize = 16; + #[derive(Clone)] pub(crate) struct MockRng(u64); @@ -133,6 +191,7 @@ impl RngCore for MockRng { } +#[allow(unused)] #[allow(unsafe_op_in_unsafe_fn)] #[cxx::bridge(namespace = brave_wallet::orchard)] mod ffi { @@ -147,30 +206,130 @@ mod ffi { use_memo: bool } + struct MerkleHash { + hash: [u8; 32] + } + + struct MerklePath { + position: u32, + auth_path: Vec, + root: MerkleHash + } + + struct OrchardSpend { + fvk: [u8; 96], + sk: [u8; 32], + // Note value + value: u32, + addr: [u8; 43], + rho: [u8; 32], + r: [u8; 32], + // Witness merkle path + merkle_path: MerklePath + } + // Encoded orchard output extracted from the transaction struct OrchardCompactAction { nullifier: [u8; 32], // kOrchardNullifierSize ephemeral_key: [u8; 32], // kOrchardEphemeralKeySize cmx: [u8; 32], // kOrchardCmxSize - enc_cipher_text : [u8; 52] // kOrchardCipherTextSize + enc_cipher_text : [u8; 52], // kOrchardCipherTextSize + block_id: u32, + is_block_last_action: bool + } + + // Represents information about tree state at the end of the block prior to the scan range + #[derive(Clone)] + struct ShardTreeState { + // Frontier is a compressed representation of merkle tree state at some leaf position + // It allows to compute merkle path to the next leafs without storing all the tree + // May be empty if no frontier inserted(In case of append) + frontier: Vec, + // Block height of the previous block prior to the scan range + // The height of the block + block_height: u32, + // Tree size of the tree at the end of the prior block, used to calculate leafs indexes + tree_size: u32 + } + + #[derive(Clone)] + struct ShardTreeCheckpointRetention { + checkpoint: bool, + marked: bool, + checkpoint_id: u32 + } + + #[derive(Clone)] + struct ShardTreeLeaf { + hash: [u8; 32], + retention: ShardTreeCheckpointRetention + } + + #[derive(Clone)] + struct ShardTreeLeafs { + commitments: Vec + } + + #[repr(u32)] + enum ShardStoreStatusCode { + Ok = 0, + None = 1, + Error = 2 + } + + #[derive(Default)] + struct ShardTreeAddress { + level: u8, + index: u32 + } + + #[derive(Default)] + struct ShardTreeCap { + data: Vec, + } + + #[derive(Default)] + struct ShardTreeShard { + address: ShardTreeAddress, + // Maybe empty on uncompleted shards + hash: Vec, + data: Vec + } + + #[derive(Default)] + struct ShardTreeCheckpoint { + empty: bool, + position: u32, + mark_removed: Vec + } + + #[derive(Default)] + struct ShardTreeCheckpointBundle { + checkpoint_id: u32, + checkpoint: ShardTreeCheckpoint } extern "Rust" { type OrchardExtendedSpendingKey; type OrchardUnauthorizedBundle; type OrchardAuthorizedBundle; - type BatchOrchardDecodeBundle; + type OrchardShardTreeBundle; + type OrchardTestingShardTreeBundle; + type OrchardWitnessBundle; type OrchardExtendedSpendingKeyResult; type OrchardUnauthorizedBundleResult; type OrchardAuthorizedBundleResult; - + type OrchardWitnessBundleResult; type BatchOrchardDecodeBundleResult; + type OrchardTestingShardTreeBundleResult; + type OrchardShardTreeBundleResult; // OsRng is used fn create_orchard_bundle( tree_state: &[u8], + spends: Vec, outputs: Vec ) -> Box; @@ -178,6 +337,7 @@ mod ffi { // Must not be used in production, only in tests. fn create_testing_orchard_bundle( tree_state: &[u8], + spends: Vec, outputs: Vec, rng_seed: u64 ) -> Box; @@ -192,6 +352,7 @@ mod ffi { fn batch_decode( fvk_bytes: &[u8; 96], // Array size should match kOrchardFullViewKeySize + prior_tree_state: ShardTreeState, actions: Vec ) -> Box; @@ -215,6 +376,10 @@ mod ffi { self: &OrchardExtendedSpendingKey ) -> [u8; 96]; // Array size sohuld match kOrchardFullViewKeySize + fn spending_key( + self: &OrchardExtendedSpendingKey + ) -> [u8; 32]; // Array size should match kSpendingKeySize + fn is_ok(self: &OrchardAuthorizedBundleResult) -> bool; fn error_message(self: &OrchardAuthorizedBundleResult) -> String; fn unwrap(self: &OrchardAuthorizedBundleResult) -> Box; @@ -231,7 +396,20 @@ mod ffi { fn note_value(self :&BatchOrchardDecodeBundle, index: usize) -> u32; // Result array size should match kOrchardNullifierSize // fvk array size should match kOrchardFullViewKeySize - fn note_nullifier(self :&BatchOrchardDecodeBundle, fvk: &[u8; 96], index: usize) -> [u8; 32]; + fn note_nullifier(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 32]; + fn note_rho(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 32]; + fn note_rseed(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 32]; + fn note_addr(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 43]; + fn note_block_height(self :&BatchOrchardDecodeBundle, index: usize) -> u32; + fn note_commitment_tree_position(self :&BatchOrchardDecodeBundle, index: usize) -> u32; + + fn is_ok(self: &OrchardShardTreeBundleResult) -> bool; + fn error_message(self: &OrchardShardTreeBundleResult) -> String; + fn unwrap(self: &OrchardShardTreeBundleResult) -> Box; + + fn is_ok(self: &OrchardTestingShardTreeBundleResult) -> bool; + fn error_message(self: &OrchardTestingShardTreeBundleResult) -> String; + fn unwrap(self: &OrchardTestingShardTreeBundleResult) -> Box; // Orchard digest is desribed here https://zips.z.cash/zip-0244#t-4-orchard-digest // Used in constructing signature digest and tx id @@ -243,17 +421,124 @@ mod ffi { // Orchard part of v5 transaction as described in // https://zips.z.cash/zip-0225 fn raw_tx(self: &OrchardAuthorizedBundle) -> Vec; + + // Witness is used to construct zk-proof for the transaction + fn is_ok(self: &OrchardWitnessBundleResult) -> bool; + fn error_message(self: &OrchardWitnessBundleResult) -> String; + fn unwrap(self: &OrchardWitnessBundleResult) -> Box; + fn size(self :&OrchardWitnessBundle) -> usize; + fn item(self: &OrchardWitnessBundle, index: usize) -> [u8; 32]; + + // Creates shard tree of default orchard height + fn create_shard_tree( + ctx: UniquePtr + ) -> Box; + // Creates shard tree of smaller size for testing purposes + fn create_testing_shard_tree( + ctx: UniquePtr + ) -> Box; + + fn insert_commitments( + self: &mut OrchardShardTreeBundle, + scan_result: &mut BatchOrchardDecodeBundle) -> bool; + fn calculate_witness( + self: &mut OrchardShardTreeBundle, + commitment_tree_position: u32, + checkpoint: u32) -> Box; + fn truncate(self: &mut OrchardShardTreeBundle, checkpoint_id: u32) -> bool; + + fn insert_commitments( + self: &mut OrchardTestingShardTreeBundle, + scan_result: &mut BatchOrchardDecodeBundle) -> bool; + fn calculate_witness( + self: &mut OrchardTestingShardTreeBundle, + commitment_tree_position: u32, + checkpoint: u32) -> Box; + fn truncate(self: &mut OrchardTestingShardTreeBundle, checkpoint_id: u32) -> bool; + + // Size matches kOrchardCmxSize in zcash_utils + fn create_mock_commitment(position: u32, seed: u32) -> [u8; 32]; + fn create_mock_decode_result( + prior_tree_state: ShardTreeState, + commitments: ShardTreeLeafs) -> Box; + } + + unsafe extern "C++" { + include!("brave/components/brave_wallet/browser/zcash/rust/cxx/src/shard_store.h"); + + type ShardStoreContext; + + fn shard_store_last_shard( + ctx: &ShardStoreContext, into: &mut ShardTreeShard) -> ShardStoreStatusCode; + fn shard_store_get_shard( + ctx: &ShardStoreContext, + addr: &ShardTreeAddress, + tree: &mut ShardTreeShard) -> ShardStoreStatusCode; + fn shard_store_put_shard( + ctx: Pin<&mut ShardStoreContext>, + tree: &ShardTreeShard) -> ShardStoreStatusCode; + fn shard_store_get_shard_roots( + ctx: &ShardStoreContext, into: &mut Vec) -> ShardStoreStatusCode; + fn shard_store_truncate( + ctx: Pin<&mut ShardStoreContext>, + address: &ShardTreeAddress) -> ShardStoreStatusCode; + fn shard_store_get_cap( + ctx: &ShardStoreContext, + into: &mut ShardTreeCap) -> ShardStoreStatusCode; + fn shard_store_put_cap( + ctx: Pin<&mut ShardStoreContext>, + tree: &ShardTreeCap) -> ShardStoreStatusCode; + fn shard_store_min_checkpoint_id( + ctx: &ShardStoreContext, into: &mut u32) -> ShardStoreStatusCode; + fn shard_store_max_checkpoint_id( + ctx: &ShardStoreContext, into: &mut u32) -> ShardStoreStatusCode; + fn shard_store_add_checkpoint( + ctx: Pin<&mut ShardStoreContext>, + checkpoint_id: u32, + checkpoint: &ShardTreeCheckpoint) -> ShardStoreStatusCode; + fn shard_store_update_checkpoint( + ctx: Pin<&mut ShardStoreContext>, + checkpoint_id: u32, + checkpoint: &ShardTreeCheckpoint) -> ShardStoreStatusCode; + fn shard_store_checkpoint_count( + ctx: &ShardStoreContext, + into: &mut usize) -> ShardStoreStatusCode; + fn shard_store_get_checkpoint_at_depth( + ctx: &ShardStoreContext, + depth: usize, + into_checkpoint_id: &mut u32, + into_checkpoint: &mut ShardTreeCheckpoint) -> ShardStoreStatusCode; + fn shard_store_get_checkpoint( + ctx: &ShardStoreContext, + checkpoint_id: u32, + into: &mut ShardTreeCheckpoint) -> ShardStoreStatusCode; + fn shard_store_remove_checkpoint( + ctx: Pin<&mut ShardStoreContext>, + checkpoint_id: u32) -> ShardStoreStatusCode; + fn shard_store_truncate_checkpoint( + ctx: Pin<&mut ShardStoreContext>, + checkpoint_id: u32) -> ShardStoreStatusCode; + fn shard_store_get_checkpoints( + ctx: &ShardStoreContext, + limit: usize, + into: &mut Vec) -> ShardStoreStatusCode; + } + } #[derive(Debug)] pub enum Error { Zip32(Zip32Error), - OrchardBuilder(OrchardBuildError), + WrongInputError, WrongOutputError, BuildError, FvkError, OrchardActionFormatError, + ShardStoreError, + OrchardBuilder(OrchardBuildError), + WitnessError, + SpendError, } impl_error!(Zip32Error, Zip32); @@ -267,11 +552,21 @@ impl fmt::Display for Error { Error::WrongOutputError => write!(f, "Error: Can't parse output"), Error::BuildError => write!(f, "Error, build error"), Error::OrchardActionFormatError => write!(f, "Error, orchard action format error"), - Error::FvkError => write!(f, "Error, fvk format error") + Error::FvkError => write!(f, "Error, fvk format error"), + Error::ShardStoreError => write!(f, "Shard store error"), + Error::WrongInputError => write!(f, "Wrong input error"), + Error::WitnessError => write!(f, "Witness error"), + Error::SpendError => write!(f, "Spend error"), } } } +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + Some(self) + } +} + // Different random sources are used for testing and for release // Since Orchard uses randomness we need to mock it to get // deterministic resuluts in tests. @@ -286,7 +581,8 @@ enum OrchardRandomSource { #[derive(Clone)] pub struct OrchardUnauthorizedBundleValue { unauthorized_bundle: Bundle, Amount>, - rng: OrchardRandomSource + rng: OrchardRandomSource, + asks: Vec } // Authorized bundle is a bundle where inputs are signed with signature digests @@ -298,11 +594,54 @@ pub struct OrchardAuthorizedBundleValue { #[derive(Clone)] pub struct DecryptedOrchardOutput { - note: ::Note + note: ::Note, + block_height: u32, + commitment_tree_position: u32 } + #[derive(Clone)] pub struct BatchOrchardDecodeBundleValue { - outputs: Vec + fvk: [u8; 96], + outputs: Vec, + commitments: Vec<(MerkleHashOrchard, Retention)>, + prior_tree_state: ShardTreeState +} + +pub struct OrchardGenericShardTreeBundleValue { + tree: ShardTree, T, S> +} + +type OrchardShardTreeBundleValue = + OrchardGenericShardTreeBundleValue; +type OrchardTestingShardTreeBundleValue = + OrchardGenericShardTreeBundleValue; + +#[derive(Clone)] +pub struct MarkleHashVec(Vec); + +impl From> for MarkleHashVec { + fn from(item: incrementalmerkletree::MerklePath) -> Self { + let mut result : Vec = vec![]; + for elem in item.path_elems() { + result.push(*elem); + } + MarkleHashVec(result) + } +} + +#[derive(Clone)] +pub struct OrchardWitnessBundleValue { + path: MarkleHashVec +} + +impl +Clone for OrchardGenericShardTreeBundleValue { + fn clone(&self) -> Self { + OrchardGenericShardTreeBundleValue { + tree : ShardTree::new(self.tree.store().clone(), + PRUNING_DEPTH.into()) + } + } } #[derive(Clone)] @@ -313,17 +652,28 @@ struct OrchardAuthorizedBundle(OrchardAuthorizedBundleValue); struct OrchardUnauthorizedBundle(OrchardUnauthorizedBundleValue); #[derive(Clone)] struct BatchOrchardDecodeBundle(BatchOrchardDecodeBundleValue); +#[derive(Clone)] +struct OrchardShardTreeBundle(OrchardShardTreeBundleValue); +#[derive(Clone)] +struct OrchardTestingShardTreeBundle(OrchardTestingShardTreeBundleValue); +#[derive(Clone)] +struct OrchardWitnessBundle(OrchardWitnessBundleValue); struct OrchardExtendedSpendingKeyResult(Result); struct OrchardAuthorizedBundleResult(Result); struct OrchardUnauthorizedBundleResult(Result); struct BatchOrchardDecodeBundleResult(Result); +struct OrchardShardTreeBundleResult(Result); +struct OrchardWitnessBundleResult(Result); +struct OrchardTestingShardTreeBundleResult(Result); impl_result!(OrchardExtendedSpendingKey, OrchardExtendedSpendingKeyResult, ExtendedSpendingKey); impl_result!(OrchardAuthorizedBundle, OrchardAuthorizedBundleResult, OrchardAuthorizedBundleValue); impl_result!(OrchardUnauthorizedBundle, OrchardUnauthorizedBundleResult, OrchardUnauthorizedBundleValue); - impl_result!(BatchOrchardDecodeBundle, BatchOrchardDecodeBundleResult, BatchOrchardDecodeBundleValue); +impl_result!(OrchardShardTreeBundle, OrchardShardTreeBundleResult, OrchardShardTreeBundleValue); +impl_result!(OrchardTestingShardTreeBundle, OrchardTestingShardTreeBundleResult, OrchardTestingShardTreeBundleValue); +impl_result!(OrchardWitnessBundle, OrchardWitnessBundleResult, OrchardWitnessBundleValue); fn generate_orchard_extended_spending_key_from_seed( bytes: &[u8] @@ -367,6 +717,12 @@ impl OrchardExtendedSpendingKey { ) -> [u8; 96] { OrchardFVK::from(&self.0).to_bytes() } + + fn spending_key( + self: &OrchardExtendedSpendingKey + ) -> [u8; 32] { + *self.0.sk().to_bytes() + } } impl OrchardAuthorizedBundle { @@ -377,19 +733,18 @@ impl OrchardAuthorizedBundle { fn create_orchard_builder_internal( orchard_tree_bytes: &[u8], + spends: Vec, outputs: Vec, random_source: OrchardRandomSource ) -> Box { - use orchard::Anchor; - use zcash_primitives::merkle_tree::read_commitment_tree; - // To construct transaction orchard tree state of some block should be provided // But in tests we can use empty anchor. let anchor = if orchard_tree_bytes.len() > 0 { match read_commitment_tree::( &orchard_tree_bytes[..]) { Ok(tree) => Anchor::from(tree.root()), - Err(_e) => return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::from(OrchardBuildError::AnchorMismatch)))), + Err(_e) => return Box::new(OrchardUnauthorizedBundleResult::from( + Err(Error::from(OrchardBuildError::AnchorMismatch)))), } } else { orchard::Anchor::empty_tree() @@ -399,11 +754,73 @@ fn create_orchard_builder_internal( orchard::builder::BundleType::DEFAULT, anchor); + let mut asks: Vec = vec![]; + + for spend in spends { + let fvk = OrchardFVK::from_bytes(&spend.fvk); + if fvk.is_none().into() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::FvkError))) + } + + let auth_path = spend.merkle_path.auth_path.iter().map(|v| { + let hash = MerkleHashOrchard::from_bytes(&v.hash); + if hash.is_some().into() { + Ok(hash.unwrap()) + } else { + Err(Error::WitnessError) + } + }).collect::, _>>(); + + if auth_path.is_err() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::WitnessError))) + } + + let auth_path_sized : Result<[MerkleHashOrchard; orchard::NOTE_COMMITMENT_TREE_DEPTH], _> = auth_path.unwrap().try_into(); + if auth_path_sized.is_err() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::WitnessError))) + } + + let merkle_path = MerklePath::from_parts( + spend.merkle_path.position, + auth_path_sized.unwrap(), + ); + + let rho = Rho::from_bytes(&spend.rho); + if rho.is_none().into() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::OrchardActionFormatError))) + } + let rseed = RandomSeed::from_bytes(spend.r, &rho.unwrap()); + if rseed.is_none().into() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::OrchardActionFormatError))) + } + let addr = orchard::Address::from_raw_address_bytes(&spend.addr); + if addr.is_none().into() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::WrongInputError))) + } + + let note = orchard::Note::from_parts( + addr.unwrap(), + NoteValue::from_raw(u64::from(spend.value)), + rho.unwrap().clone(), + rseed.unwrap()); + + if note.is_none().into() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::OrchardActionFormatError))) + } + + let add_spend_result = builder.add_spend(fvk.unwrap(), note.unwrap(), merkle_path); + if add_spend_result.is_err() { + return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::SpendError))) + } + asks.push(SpendAuthorizingKey::from(&SpendingKey::from_bytes(spend.sk).unwrap())); + } + for out in outputs { let _ = match Option::from(orchard::Address::from_raw_address_bytes(&out.addr)) { Some(addr) => { builder.add_output(None, addr, - orchard::value::NoteValue::from_raw(u64::from(out.value)), if out.use_memo { Some(out.memo)} else { Option::None }) + orchard::value::NoteValue::from_raw( + u64::from(out.value)), if out.use_memo { Some(out.memo)} else { Option::None }) }, None => return Box::new(OrchardUnauthorizedBundleResult::from(Err(Error::WrongOutputError))) }; @@ -416,7 +833,8 @@ fn create_orchard_builder_internal( .and_then(|builder| { builder.map(|bundle| OrchardUnauthorizedBundleValue { unauthorized_bundle: bundle.0, - rng: OrchardRandomSource::OsRng(rng) }).ok_or(Error::BuildError) + rng: OrchardRandomSource::OsRng(rng), + asks: asks }).ok_or(Error::BuildError) }) }, OrchardRandomSource::MockRng(mut rng) => { @@ -425,7 +843,8 @@ fn create_orchard_builder_internal( .and_then(|builder| { builder.map(|bundle| OrchardUnauthorizedBundleValue { unauthorized_bundle: bundle.0, - rng: OrchardRandomSource::MockRng(rng) }).ok_or(Error::BuildError) + rng: OrchardRandomSource::MockRng(rng), asks: asks + }).ok_or(Error::BuildError) }) } })) @@ -433,17 +852,19 @@ fn create_orchard_builder_internal( fn create_orchard_bundle( orchard_tree_bytes: &[u8], + spends: Vec, outputs: Vec ) -> Box { - create_orchard_builder_internal(orchard_tree_bytes, outputs, OrchardRandomSource::OsRng(OsRng)) + create_orchard_builder_internal(orchard_tree_bytes, spends, outputs, OrchardRandomSource::OsRng(OsRng)) } fn create_testing_orchard_bundle( orchard_tree_bytes: &[u8], + spends: Vec, outputs: Vec, rng_seed: u64 ) -> Box { - create_orchard_builder_internal(orchard_tree_bytes, outputs, OrchardRandomSource::MockRng(MockRng(rng_seed))) + create_orchard_builder_internal(orchard_tree_bytes, spends, outputs, OrchardRandomSource::MockRng(MockRng(rng_seed))) } impl OrchardUnauthorizedBundle { @@ -461,7 +882,7 @@ impl OrchardUnauthorizedBundle { b.apply_signatures( &mut rng, sighash, - &[], + &self.0.asks, ) }) }, @@ -472,7 +893,7 @@ impl OrchardUnauthorizedBundle { b.apply_signatures( &mut rng, sighash, - &[], + &self.0.asks, ) }) } @@ -500,6 +921,7 @@ impl ShieldedOutput for OrchardCompactAction { fn batch_decode( fvk_bytes: &[u8; 96], + prior_tree_state: ShardTreeState, actions: Vec ) -> Box { let fvk = match OrchardFVK::from_bytes(fvk_bytes) { @@ -532,7 +954,8 @@ fn batch_decode( let ephemeral_key = EphemeralKeyBytes(v.ephemeral_key); let enc_cipher_text = v.enc_cipher_text; - let compact_action = CompactAction::from_parts(nullifier, cmx, ephemeral_key, enc_cipher_text); + let compact_action = + CompactAction::from_parts(nullifier, cmx, ephemeral_key, enc_cipher_text); let orchard_domain = OrchardDomain::for_compact_action(&compact_action); Ok((orchard_domain, v)) @@ -544,17 +967,57 @@ fn batch_decode( Err(e) => return Box::new(BatchOrchardDecodeBundleResult::from(Err(e.into()))) }; - let decrypted_outputs = batch::try_compact_note_decryption(&ivks, &input_actions.as_slice()) - .into_iter() - .map(|res| { - res.map(|((note, _recipient), _ivk_idx)| DecryptedOrchardOutput { - note: note - }) - }) - .filter_map(|x| x) - .collect::>(); + let mut decrypted_len = 0; + let (decrypted_opts, _decrypted_len) = ( + batch::try_compact_note_decryption(&ivks, &input_actions) + .into_iter() + .map(|v| { + v.map(|((note, _), ivk_idx)| { + decrypted_len += 1; + (ivks[ivk_idx].clone(), note) + }) + }) + .collect::>(), + decrypted_len, + ); + + let mut found_notes: Vec = vec![]; + let mut note_commitments: Vec<(MerkleHashOrchard, Retention)> = vec![]; + + for (output_idx, ((_, output), decrypted_note)) in + input_actions.iter().zip(decrypted_opts).enumerate() { + // If the commitment is the last in the block, ensure that is is retained as a checkpoint + let is_checkpoint = &output.is_block_last_action; + let block_id = &output.block_id; + let retention : Retention = match (decrypted_note.is_some(), is_checkpoint) { + (is_marked, true) => Retention::Checkpoint { + id: BlockHeight::from_u32(*block_id), + is_marked, + }, + (true, false) => Retention::Marked, + (false, false) => Retention::Ephemeral, + }; + let commitment = MerkleHashOrchard::from_bytes(&output.cmx); + if commitment.is_none().into() { + return Box::new(BatchOrchardDecodeBundleResult::from(Err(Error::OrchardActionFormatError))) + } + note_commitments.push((commitment.unwrap(), retention)); + + if let Some((_key_id, note)) = decrypted_note { + found_notes.push(DecryptedOrchardOutput{ + note: note, + block_height: output.block_id, + commitment_tree_position: (output_idx as u32) + prior_tree_state.tree_size + }); + } + } - Box::new(BatchOrchardDecodeBundleResult::from(Ok(BatchOrchardDecodeBundleValue { outputs: decrypted_outputs }))) + Box::new(BatchOrchardDecodeBundleResult::from(Ok(BatchOrchardDecodeBundleValue { + fvk: *fvk_bytes, + outputs: found_notes, + commitments: note_commitments, + prior_tree_state: prior_tree_state + }))) } impl BatchOrchardDecodeBundle { @@ -567,8 +1030,581 @@ impl BatchOrchardDecodeBundle { "Outputs are always created from a u32, so conversion back will always succeed") } - fn note_nullifier(self :&BatchOrchardDecodeBundle, fvk: &[u8; 96], index: usize) -> [u8; 32] { - self.0.outputs[index].note.nullifier(&OrchardFVK::from_bytes(fvk).unwrap()).to_bytes() + fn note_nullifier(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 32] { + self.0.outputs[index].note.nullifier(&OrchardFVK::from_bytes(&self.0.fvk).unwrap()).to_bytes() + } + + fn note_block_height(self :&BatchOrchardDecodeBundle, index: usize) -> u32 { + self.0.outputs[index].block_height + } + + fn note_rho(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 32] { + self.0.outputs[index].note.rho().to_bytes() + + } + + fn note_rseed(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 32] { + *self.0.outputs[index].note.rseed().as_bytes() + } + + fn note_addr(self :&BatchOrchardDecodeBundle, index: usize) -> [u8; 43] { + self.0.outputs[index].note.recipient().to_raw_address_bytes() + } + + fn note_commitment_tree_position(self :&BatchOrchardDecodeBundle, index: usize) -> u32 { + self.0.outputs[index].commitment_tree_position + } +} + +fn insert_frontier( + tree: &mut ShardTree, COMMITMENT_TREE_DEPTH, SHARD_HEIGHT>, + frontier: &Vec +) -> bool { + let frontier_commitment_tree = read_commitment_tree::( + &frontier[..]); + + if frontier_commitment_tree.is_err() { + return false; + } + + let frontier_result = tree.insert_frontier( + frontier_commitment_tree.unwrap().to_frontier(), + Retention::Marked, + ); + + frontier_result.is_ok() +} + +fn insert_commitments( + shard_tree: &mut ShardTree, COMMITMENT_TREE_DEPTH, SHARD_HEIGHT>, + scan_result: &mut BatchOrchardDecodeBundle) -> bool { + let start_position : u64 = scan_result.0.prior_tree_state.tree_size.into(); + + if !scan_result.0.prior_tree_state.frontier.is_empty() { + let frontier_result = insert_frontier::( + shard_tree, &scan_result.0.prior_tree_state.frontier); + if !frontier_result { + return false; + } + } + + let batch_insert_result = shard_tree.batch_insert( + Position::from(start_position), + scan_result.0.commitments.clone().into_iter()); + + if batch_insert_result.is_err() { + return false; + } + + true +} + +impl From<&[MerkleHashOrchard]> for MarkleHashVec { + fn from(item: &[MerkleHashOrchard]) -> Self { + let mut result : Vec = vec![]; + for elem in item { + result.push(*elem); + } + MarkleHashVec(result) + } +} + +impl OrchardShardTreeBundle { + fn insert_commitments(self: &mut OrchardShardTreeBundle, + scan_result: &mut BatchOrchardDecodeBundle) -> bool { + insert_commitments::(&mut self.0.tree, scan_result) + } + + fn calculate_witness(self: &mut OrchardShardTreeBundle, + commitment_tree_position: u32, + checkpoint: u32) -> Box { + match self.0.tree.witness_at_checkpoint_id_caching(( + commitment_tree_position as u64).into(), &checkpoint.into()) { + Ok(witness) => Box::new(OrchardWitnessBundleResult::from( + Ok(OrchardWitnessBundleValue { path: witness.path_elems().into() }))), + Err(_e) => Box::new(OrchardWitnessBundleResult::from(Err(Error::WitnessError))) + } + } + + fn truncate(self: &mut OrchardShardTreeBundle, checkpoint: u32) -> bool { + self.0.tree.truncate_removing_checkpoint(&BlockHeight::from_u32(checkpoint)).is_ok() + } +} + +impl OrchardTestingShardTreeBundle { + fn insert_commitments(self: &mut OrchardTestingShardTreeBundle, scan_result: &mut BatchOrchardDecodeBundle) -> bool { + insert_commitments::(&mut self.0.tree, scan_result) + } + + fn calculate_witness(self: &mut OrchardTestingShardTreeBundle, + commitment_tree_position: u32, checkpoint: u32) -> Box { + match self.0.tree.witness_at_checkpoint_id_caching((commitment_tree_position as u64).into(), &checkpoint.into()) { + Ok(witness) => Box::new(OrchardWitnessBundleResult::from(Ok(OrchardWitnessBundleValue { path: witness.into() }))), + Err(_e) => Box::new(OrchardWitnessBundleResult::from(Err(Error::WitnessError))) + } + } + + fn truncate(self: &mut OrchardTestingShardTreeBundle, checkpoint: u32) -> bool { + let result = self.0.tree.truncate_removing_checkpoint(&BlockHeight::from_u32(checkpoint)); + return result.is_ok() && result.unwrap(); + } +} + +impl OrchardWitnessBundle { + fn size(self: &OrchardWitnessBundle) -> usize { + self.0.path.0.len() + } + + fn item(self: &OrchardWitnessBundle, index: usize) -> [u8; 32] { + self.0.path.0[index].to_bytes() + } +} + +#[derive(Clone)] +pub struct CxxShardStoreImpl { + native_context: Rc>>, + _hash_type: PhantomData, +} + +impl From<&ShardTreeCheckpoint> for Checkpoint { + fn from(item: &ShardTreeCheckpoint) -> Self { + let tree_state : TreeState = + if item.empty { TreeState::Empty } else { TreeState::AtPosition((item.position as u64).into()) }; + let marks_removed : BTreeSet = + item.mark_removed.iter().map(|x| Position::from(*x as u64)).collect(); + Checkpoint::from_parts(tree_state, marks_removed) + } +} + +impl TryFrom<&Checkpoint> for ShardTreeCheckpoint { + type Error = Error; + fn try_from(item: &Checkpoint) -> Result { + let position: u32 = match item.tree_state() { + TreeState::Empty => 0, + TreeState::AtPosition(pos) => (u64::from(pos)).try_into().map_err(|_| Error::ShardStoreError)? + }; + let marks_removed : Result, Error> = item.marks_removed().into_iter().map( + |x| u32::try_from(u64::from(*x)).map_err(|_| Error::ShardStoreError)).collect(); + Ok(ShardTreeCheckpoint { + empty: item.is_tree_empty(), + position: position, + mark_removed: marks_removed? + }) + } +} + +impl TryFrom<&Address> for ShardTreeAddress { + type Error = Error; + + fn try_from(item: &Address) -> Result { + let index : u32 = item.index().try_into().map_err(|_| Error::ShardStoreError)?; + Ok(ShardTreeAddress{ + level: item.level().into(), + index: index }) + } +} + +impl From<&ShardTreeAddress> for Address { + fn from(item: &ShardTreeAddress) -> Self { + Address::from_parts(item.level.into(), item.index.into()) + } +} + +impl TryFrom<&ShardTreeShard> for LocatedPrunableTree { + type Error = Error; + + fn try_from(item: &ShardTreeShard) -> Result { + let shard_tree = + read_shard(&mut Cursor::new(&item.data)).map_err(|_| Error::ShardStoreError)?; + let located_tree: LocatedTree<_, (_, RetentionFlags)> = + LocatedPrunableTree::from_parts(Address::from(&item.address), shard_tree); + if !item.hash.is_empty() { + let root_hash = H::read(Cursor::new(item.hash.clone())).map_err(|_| Error::ShardStoreError)?; + Ok(located_tree.reannotate_root(Some(Arc::new(root_hash)))) + } else { + Ok(located_tree) + } + } +} + +impl TryFrom<&ShardTreeCap> for PrunableTree { + type Error = Error; + + fn try_from(item: &ShardTreeCap) -> Result { + read_shard(&mut Cursor::new(&item.data)).map_err(|_| Error::ShardStoreError) + } +} + +impl TryFrom<&PrunableTree> for ShardTreeCap { + type Error = Error; + + fn try_from(item: &PrunableTree) -> Result { + let mut data = vec![]; + write_shard(&mut data, item).map_err(|_| Error::ShardStoreError)?; + Ok(ShardTreeCap { + data: data + }) + } +} + +impl TryFrom<&LocatedPrunableTree> for ShardTreeShard { + type Error = Error; + + fn try_from(item: &LocatedPrunableTree) -> Result { + let subtree_root_hash : Option> = item + .root() + .annotation() + .and_then(|ann| { + ann.as_ref().map(|rc| { + let mut root_hash = vec![]; + rc.write(&mut root_hash)?; + Ok(root_hash) + }) + }) + .transpose() + .map_err(|_err : std::io::Error| Error::ShardStoreError)?; + + + let mut result = ShardTreeShard { + address: ShardTreeAddress::try_from(&item.root_addr()).map_err(|_| Error::ShardStoreError)?, + hash: subtree_root_hash.unwrap_or_else(|| vec![]).try_into().map_err(|_| Error::ShardStoreError)?, + data: vec![] + }; + + write_shard(&mut result.data, &item.root()).map_err(|_| Error::ShardStoreError)?; + Ok(result) + } +} + +type OrchardCxxShardStoreImpl = CxxShardStoreImpl; +type TestingCxxShardStoreImpl = CxxShardStoreImpl; + +impl ShardStore + for CxxShardStoreImpl +{ + type H = H; + type CheckpointId = BlockHeight; + type Error = Error; + + fn get_shard( + &self, + addr: Address, + ) -> Result>, Self::Error> { + let ctx = self.native_context.clone(); + let mut into = ShardTreeShard::default(); + let result = shard_store_get_shard(&*ctx.try_borrow().unwrap(), + &ShardTreeAddress::try_from(&addr).map_err(|_| Error::ShardStoreError)?, + &mut into); + if result == ShardStoreStatusCode::Ok { + let tree = LocatedPrunableTree::::try_from(&into)?; + return Ok(Some(tree)); + } else if result == ShardStoreStatusCode::None { + return Ok(Option::None); + } else { + return Err(Error::ShardStoreError); + } + } + + fn last_shard(&self) -> Result>, Self::Error> { + let ctx = self.native_context.clone(); + let mut into = ShardTreeShard::default(); + let result = + shard_store_last_shard(&*ctx.try_borrow().unwrap(), &mut into); + if result == ShardStoreStatusCode::Ok { + let tree = LocatedPrunableTree::::try_from(&into)?; + return Ok(Some(tree)); + } else if result == ShardStoreStatusCode::None { + return Ok(Option::None); + } else { + return Err(Error::ShardStoreError); + } + } + + fn put_shard(&mut self, subtree: LocatedPrunableTree) -> Result<(), Self::Error> { + let ctx = self.native_context.clone(); + let shard = ShardTreeShard::try_from(&subtree).map_err(|_| Error::ShardStoreError)?; + let result = + shard_store_put_shard(ctx.try_borrow_mut().unwrap().pin_mut(), &shard); + if result == ShardStoreStatusCode::Ok { + return Ok(()); + } + return Err(Error::ShardStoreError); + } + + fn get_shard_roots(&self) -> Result, Self::Error> { + let ctx = self.native_context.clone(); + let mut input : Vec = vec![]; + let result = shard_store_get_shard_roots(&*ctx.try_borrow().unwrap(), &mut input); + if result == ShardStoreStatusCode::Ok { + return Ok(input.into_iter().map(|res| { + Address::from_parts(res.level.into(), res.index.into()) + }).collect()) + } else if result == ShardStoreStatusCode::None { + return Ok(vec![]) + } else { + return Err(Error::ShardStoreError) + } + } + + fn truncate(&mut self, from: Address) -> Result<(), Self::Error> { + let ctx = self.native_context.clone(); + let result = + shard_store_truncate(ctx.try_borrow_mut().unwrap().pin_mut(), + &ShardTreeAddress::try_from(&from).map_err(|_| Error::ShardStoreError)?); + if result == ShardStoreStatusCode::Ok || result == ShardStoreStatusCode::None { + return Ok(()); + } else { + return Err(Error::ShardStoreError) + } + } + + fn get_cap(&self) -> Result, Self::Error> { + let ctx = self.native_context.clone(); + let mut input = ShardTreeCap::default(); + let result = + shard_store_get_cap(&*ctx.try_borrow().unwrap(), &mut input); + + if result == ShardStoreStatusCode::Ok { + let tree = PrunableTree::::try_from(&input)?; + return Ok(tree) + } else + if result == ShardStoreStatusCode::None { + return Ok(PrunableTree::empty()); + } else { + return Err(Error::ShardStoreError); + } + } + + fn put_cap(&mut self, cap: PrunableTree) -> Result<(), Self::Error> { + let ctx = self.native_context.clone(); + let mut result_cap = ShardTreeCap::default(); + write_shard(&mut result_cap.data, &cap).map_err(|_| Error::ShardStoreError)?; + + let result = + shard_store_put_cap(ctx.try_borrow_mut().unwrap().pin_mut(), &result_cap); + if result == ShardStoreStatusCode::Ok { + return Ok(()); + } + return Err(Error::ShardStoreError); + } + + fn min_checkpoint_id(&self) -> Result, Self::Error> { + let ctx = self.native_context.clone(); + let mut input : u32 = 0; + let result = + shard_store_min_checkpoint_id(&*ctx.try_borrow().unwrap(), &mut input); + if result == ShardStoreStatusCode::Ok { + return Ok(Some(input.into())); + } else if result == ShardStoreStatusCode::None { + return Ok(Option::None); + } + return Err(Error::ShardStoreError); + } + + fn max_checkpoint_id(&self) -> Result, Self::Error> { + let ctx = self.native_context.clone(); + let mut input : u32 = 0; + let result = shard_store_max_checkpoint_id(&*ctx.try_borrow().unwrap(), &mut input); + if result == ShardStoreStatusCode::Ok { + return Ok(Some(input.into())); + } else if result == ShardStoreStatusCode::None { + return Ok(Option::None); + } + return Err(Error::ShardStoreError); } + + fn add_checkpoint( + &mut self, + checkpoint_id: Self::CheckpointId, + checkpoint: Checkpoint, + ) -> Result<(), Self::Error> { + let ctx = self.native_context.clone(); + let ffi_checkpoint_id : u32 = checkpoint_id.try_into().map_err(|_| Error::ShardStoreError)?; + let result = + shard_store_add_checkpoint(ctx.try_borrow_mut().unwrap().pin_mut(), + ffi_checkpoint_id, + &ShardTreeCheckpoint::try_from(&checkpoint)?); + if result == ShardStoreStatusCode::Ok { + return Ok(()); + } + return Err(Error::ShardStoreError); + } + + fn checkpoint_count(&self) -> Result { + let ctx = self.native_context.clone(); + let mut input : usize = 0; + let result = shard_store_checkpoint_count(&*ctx.try_borrow().unwrap(), &mut input); + if result == ShardStoreStatusCode::Ok { + return Ok(input.into()); + } else if result == ShardStoreStatusCode::None { + return Ok(0); + } + return Err(Error::ShardStoreError); + } + + fn get_checkpoint_at_depth( + &self, + checkpoint_depth: usize, + ) -> Result, Self::Error> { + let ctx = self.native_context.clone(); + let mut input_checkpoint_id : u32 = 0; + let mut input_checkpoint : ShardTreeCheckpoint = ShardTreeCheckpoint::default(); + + let result = shard_store_get_checkpoint_at_depth(&*ctx.try_borrow().unwrap(), + checkpoint_depth, + &mut input_checkpoint_id, + &mut input_checkpoint); + + if result == ShardStoreStatusCode::Ok { + return Ok(Some((BlockHeight::from(input_checkpoint_id), Checkpoint::from(&input_checkpoint)))); + } else if result == ShardStoreStatusCode::None { + return Ok(Option::None); + } + return Ok(Option::None); + } + + fn get_checkpoint( + &self, + checkpoint_id: &Self::CheckpointId, + ) -> Result, Self::Error> { + let ctx = self.native_context.clone(); + let mut input_checkpoint : ShardTreeCheckpoint = ShardTreeCheckpoint::default(); + + let result = shard_store_get_checkpoint(&*ctx.try_borrow().unwrap(), + (*checkpoint_id).into(), + &mut input_checkpoint); + + if result == ShardStoreStatusCode::Ok { + return Ok(Some(Checkpoint::from(&input_checkpoint))); + } else if result == ShardStoreStatusCode::None { + return Ok(Option::None); + } + return Ok(Option::None); + } + + fn with_checkpoints(&mut self, limit: usize, mut callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Self::CheckpointId, &Checkpoint) -> Result<(), Self::Error>, + { + let ctx = self.native_context.clone(); + let mut into : Vec = vec![]; + let result = shard_store_get_checkpoints(&*ctx.try_borrow().unwrap(), limit, &mut into); + if result == ShardStoreStatusCode::Ok { + for item in into { + let checkpoint = Checkpoint::from(&item.checkpoint); + callback(&BlockHeight::from(item.checkpoint_id), &checkpoint).map_err(|_| Error::ShardStoreError)?; + } + return Ok(()) + } else if result == ShardStoreStatusCode::None { + return Ok(()) + } + Err(Error::ShardStoreError) + } + + fn update_checkpoint_with( + &mut self, + checkpoint_id: &Self::CheckpointId, + update: F, + ) -> Result + where + F: Fn(&mut Checkpoint) -> Result<(), Self::Error>, + { + let ctx = self.native_context.clone(); + let mut input_checkpoint = ShardTreeCheckpoint::default(); + let result_get_checkpoint = + shard_store_get_checkpoint(&*ctx.try_borrow().unwrap(), (*checkpoint_id).into(), &mut input_checkpoint); + if result_get_checkpoint == ShardStoreStatusCode::Ok { + return Ok(true); + } else if result_get_checkpoint == ShardStoreStatusCode::None { + return Ok(false); + } + + let mut checkpoint = Checkpoint::from(&input_checkpoint); + + update(&mut checkpoint).map_err(|_| Error::ShardStoreError)?; + let result_update_checkpoint = + shard_store_update_checkpoint(ctx.try_borrow_mut().unwrap().pin_mut(), + (*checkpoint_id).into(), &ShardTreeCheckpoint::try_from(&checkpoint)?); + if result_update_checkpoint == ShardStoreStatusCode::Ok { + return Ok(true); + } else if result_update_checkpoint == ShardStoreStatusCode::None { + return Ok(false); + } + return Err(Error::ShardStoreError); + } + + fn remove_checkpoint(&mut self, checkpoint_id: &Self::CheckpointId) -> Result<(), Self::Error> { + let ctx = self.native_context.clone(); + let result = + shard_store_remove_checkpoint(ctx.try_borrow_mut().unwrap().pin_mut(), (*checkpoint_id).into()); + if result == ShardStoreStatusCode::Ok { + return Ok(()); + } else if result == ShardStoreStatusCode::None { + return Ok(()); + } + return Err(Error::ShardStoreError); + } + + fn truncate_checkpoints( + &mut self, + checkpoint_id: &Self::CheckpointId, + ) -> Result<(), Self::Error> { + let ctx = self.native_context.clone(); + let result = + shard_store_truncate_checkpoint(ctx.try_borrow_mut().unwrap().pin_mut(), (*checkpoint_id).into()); + if result == ShardStoreStatusCode::Ok { + return Ok(()); + } else if result == ShardStoreStatusCode::None { + return Ok(()); + } + return Err(Error::ShardStoreError); + } +} + +fn create_shard_tree(context: UniquePtr) -> Box { + let shard_store = OrchardCxxShardStoreImpl { + native_context: Rc::new(RefCell::new(context)), + _hash_type: Default::default() + }; + let shardtree = ShardTree::new(shard_store, PRUNING_DEPTH.try_into().unwrap()); + Box::new(OrchardShardTreeBundleResult::from(Ok(OrchardShardTreeBundleValue{tree: shardtree}))) +} + +fn convert_ffi_commitments(shard_tree_leafs: &ShardTreeLeafs) -> Vec<(MerkleHashOrchard, Retention)> { + shard_tree_leafs.commitments.iter().map(|c| { + let retention:Retention = { + if c.retention.checkpoint { + Retention::Checkpoint { id: c.retention.checkpoint_id.into(), is_marked: c.retention.marked } + } else if c.retention.marked { + Retention::Marked + } else { + Retention::Ephemeral + } + }; + let mh = MerkleHashOrchard::from_bytes(&c.hash); + (mh.unwrap(), retention) + }).collect() +} + +fn create_mock_decode_result(prior_tree_state: ShardTreeState, commitments: ShardTreeLeafs) -> Box { + Box::new(BatchOrchardDecodeBundleResult::from(Ok(BatchOrchardDecodeBundleValue { + fvk: [0; 96], + outputs: vec![], + commitments: convert_ffi_commitments(&commitments), + prior_tree_state: prior_tree_state + }))) +} + +fn create_testing_shard_tree(context: UniquePtr) -> Box { + let shard_store = TestingCxxShardStoreImpl { + native_context: Rc::new(RefCell::new(context)), + _hash_type: Default::default() + }; + let shardtree = ShardTree::new(shard_store, TESTING_PRUNING_DEPTH.try_into().unwrap()); + Box::new(OrchardTestingShardTreeBundleResult::from(Ok(OrchardTestingShardTreeBundleValue{tree: shardtree}))) } +fn create_mock_commitment(position: u32, seed: u32) -> [u8; 32] { + MerkleHashOrchard::from_bytes( + &(pallas::Base::random(MockRng((position * seed).into())).into())).unwrap().to_bytes() +} \ No newline at end of file diff --git a/components/brave_wallet/browser/zcash/rust/librustzcash/BUILD.gn b/components/brave_wallet/browser/zcash/rust/librustzcash/BUILD.gn index 9fba4ac46629..1941eae34851 100644 --- a/components/brave_wallet/browser/zcash/rust/librustzcash/BUILD.gn +++ b/components/brave_wallet/browser/zcash/rust/librustzcash/BUILD.gn @@ -6,7 +6,11 @@ import("//build/rust/rust_static_library.gni") rust_static_library("zcash_protocol") { - visibility = [ ":zcash_primitives" ] + visibility = [ + ":zcash_primitives", + "//brave/components/brave_wallet/browser/zcash/rust:rust_lib", + "//brave/components/brave_wallet/browser/zcash/rust:rust_lib_cxx_generated", + ] crate_name = "zcash_protocol" crate_root = "src/components/zcash_protocol/src/lib.rs" sources = [ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_block_decoder.h b/components/brave_wallet/browser/zcash/rust/orchard_block_decoder.h index 6e4d3a382ebf..dc9f89dfc642 100644 --- a/components/brave_wallet/browser/zcash/rust/orchard_block_decoder.h +++ b/components/brave_wallet/browser/zcash/rust/orchard_block_decoder.h @@ -10,6 +10,7 @@ #include #include +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h" #include "brave/components/brave_wallet/common/zcash_utils.h" #include "brave/components/services/brave_wallet/public/mojom/zcash_decoder.mojom.h" @@ -19,8 +20,10 @@ class OrchardBlockDecoder { public: virtual ~OrchardBlockDecoder() = default; - virtual std::optional> ScanBlock( - const ::brave_wallet::zcash::mojom::CompactBlockPtr& block) = 0; + virtual std::unique_ptr ScanBlocks( + const ::brave_wallet::OrchardTreeState& tree_state, + const std::vector<::brave_wallet::zcash::mojom::CompactBlockPtr>& + blocks) = 0; static std::unique_ptr FromFullViewKey( const OrchardFullViewKey& fvk); diff --git a/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.cc b/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.cc index 69afd446414f..9bcb39bf3b26 100644 --- a/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.cc +++ b/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.cc @@ -11,6 +11,7 @@ #include "base/memory/ptr_util.h" #include "brave/components/brave_wallet/browser/zcash/rust/lib.rs.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h" #include "brave/components/brave_wallet/common/zcash_utils.h" namespace brave_wallet::orchard { @@ -20,51 +21,61 @@ OrchardBlockDecoderImpl::OrchardBlockDecoderImpl(const OrchardFullViewKey& fvk) OrchardBlockDecoderImpl::~OrchardBlockDecoderImpl() = default; -std::optional> -OrchardBlockDecoderImpl::ScanBlock( - const ::brave_wallet::zcash::mojom::CompactBlockPtr& block) { - std::vector result; - for (const auto& tx : block->vtx) { - ::rust::Vec orchard_actions; - for (const auto& orchard_action : tx->orchard_actions) { - orchard::OrchardCompactAction orchard_compact_action; - - if (orchard_action->nullifier.size() != kOrchardNullifierSize || - orchard_action->cmx.size() != kOrchardCmxSize || - orchard_action->ephemeral_key.size() != kOrchardEphemeralKeySize || - orchard_action->ciphertext.size() != kOrchardCipherTextSize) { - return std::nullopt; - } +std::unique_ptr OrchardBlockDecoderImpl::ScanBlocks( + const ::brave_wallet::OrchardTreeState& tree_state, + const std::vector<::brave_wallet::zcash::mojom::CompactBlockPtr>& blocks) { + ::rust::Vec orchard_actions; + for (const auto& block : blocks) { + bool block_has_orchard_action = false; + for (const auto& tx : block->vtx) { + for (const auto& orchard_action : tx->orchard_actions) { + block_has_orchard_action = true; + orchard::OrchardCompactAction orchard_compact_action; - base::ranges::copy(orchard_action->nullifier, - orchard_compact_action.nullifier.begin()); - base::ranges::copy(orchard_action->cmx, - orchard_compact_action.cmx.begin()); - base::ranges::copy(orchard_action->ephemeral_key, - orchard_compact_action.ephemeral_key.begin()); - base::ranges::copy(orchard_action->ciphertext, - orchard_compact_action.enc_cipher_text.begin()); + if (orchard_action->nullifier.size() != kOrchardNullifierSize || + orchard_action->cmx.size() != kOrchardCmxSize || + orchard_action->ephemeral_key.size() != kOrchardEphemeralKeySize || + orchard_action->ciphertext.size() != kOrchardCipherTextSize) { + return nullptr; + } - orchard_actions.emplace_back(std::move(orchard_compact_action)); - } + orchard_compact_action.block_id = block->height; + orchard_compact_action.is_block_last_action = false; + base::ranges::copy(orchard_action->nullifier, + orchard_compact_action.nullifier.begin()); + base::ranges::copy(orchard_action->cmx, + orchard_compact_action.cmx.begin()); + base::ranges::copy(orchard_action->ephemeral_key, + orchard_compact_action.ephemeral_key.begin()); + base::ranges::copy(orchard_action->ciphertext, + orchard_compact_action.enc_cipher_text.begin()); - ::rust::Box<::brave_wallet::orchard::BatchOrchardDecodeBundleResult> - decode_result = ::brave_wallet::orchard::batch_decode( - full_view_key_, std::move(orchard_actions)); - - if (decode_result->is_ok()) { - ::rust::Box<::brave_wallet::orchard::BatchOrchardDecodeBundle> - result_bundle = decode_result->unwrap(); - for (size_t i = 0; i < result_bundle->size(); i++) { - result.emplace_back(OrchardNote( - {block->height, result_bundle->note_nullifier(full_view_key_, i), - result_bundle->note_value(i)})); + orchard_actions.push_back(std::move(orchard_compact_action)); } - } else { - return std::nullopt; } + if (block_has_orchard_action) { + orchard_actions.back().is_block_last_action = true; + } + } + + ::brave_wallet::orchard::ShardTreeState prior_tree_state; + prior_tree_state.block_height = tree_state.block_height; + prior_tree_state.tree_size = tree_state.tree_size; + + base::ranges::copy(tree_state.frontier, + std::back_inserter(prior_tree_state.frontier)); + + ::rust::Box<::brave_wallet::orchard::BatchOrchardDecodeBundleResult> + decode_result = ::brave_wallet::orchard::batch_decode( + full_view_key_, std::move(prior_tree_state), + std::move(orchard_actions)); + + if (decode_result->is_ok()) { + return base::WrapUnique( + new OrchardDecodedBlocksBundleImpl(decode_result->unwrap())); + } else { + return nullptr; } - return result; } // static diff --git a/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.h b/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.h index f7cbb60a5d6a..fa2c75f3b86c 100644 --- a/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.h +++ b/components/brave_wallet/browser/zcash/rust/orchard_block_decoder_impl.h @@ -6,9 +6,11 @@ #ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_BLOCK_DECODER_IMPL_H_ #define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_BLOCK_DECODER_IMPL_H_ +#include #include #include "brave/components/brave_wallet/browser/zcash/rust/orchard_block_decoder.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h" namespace brave_wallet::orchard { @@ -16,8 +18,10 @@ class OrchardBlockDecoderImpl : public OrchardBlockDecoder { public: ~OrchardBlockDecoderImpl() override; - std::optional> ScanBlock( - const ::brave_wallet::zcash::mojom::CompactBlockPtr& block) override; + std::unique_ptr ScanBlocks( + const ::brave_wallet::OrchardTreeState& tree_state, + const std::vector<::brave_wallet::zcash::mojom::CompactBlockPtr>& block) + override; private: friend class OrchardBlockDecoder; diff --git a/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h b/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h new file mode 100644 index 000000000000..b0b528efdefa --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_DECODED_BLOCKS_BUNDE_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_DECODED_BLOCKS_BUNDE_H_ + +#include +#include + +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet::orchard { + +class OrchardDecodedBlocksBundle { + public: + // Builder is used in tests to create OrchardDecodedBlocksBundle with mocked + // commitments + class TestingBuilder { + public: + TestingBuilder() = default; + virtual ~TestingBuilder() = default; + virtual void AddCommitment( + const ::brave_wallet::OrchardCommitment& commitment) = 0; + virtual void SetPriorTreeState( + const ::brave_wallet::OrchardTreeState& tree_state) = 0; + virtual std::unique_ptr Complete() = 0; + }; + + static std::unique_ptr CreateTestingBuilder(); + + virtual ~OrchardDecodedBlocksBundle() {} + virtual std::optional> + GetDiscoveredNotes() = 0; +}; + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_DECODED_BLOCKS_BUNDE_H_ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.cc b/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.cc new file mode 100644 index 000000000000..929b61c24e05 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h" + +#include +#include + +#include "base/check_is_test.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet::orchard { + +class TestingBuilderImpl : public OrchardDecodedBlocksBundle::TestingBuilder { + public: + TestingBuilderImpl() {} + + ~TestingBuilderImpl() override {} + + void SetPriorTreeState( + const ::brave_wallet::OrchardTreeState& tree_state) override { + prior_tree_state_ = tree_state; + } + + void AddCommitment( + const ::brave_wallet::OrchardCommitment& commitment) override { + ShardTreeCheckpointRetention retention; + retention.marked = commitment.is_marked; + retention.checkpoint = commitment.checkpoint_id.has_value(); + retention.checkpoint_id = commitment.checkpoint_id.value_or(0); + + ShardTreeLeaf leaf; + leaf.hash = commitment.cmu; + leaf.retention = retention; + + leafs_.commitments.push_back(std::move(leaf)); + } + + std::unique_ptr Complete() override { + ::rust::Vec frontier; + base::ranges::copy(prior_tree_state_->frontier, + std::back_inserter(frontier)); + auto prior_tree_state = + ShardTreeState{frontier, prior_tree_state_->block_height, + prior_tree_state_->tree_size}; + return base::WrapUnique( + new OrchardDecodedBlocksBundleImpl( + create_mock_decode_result(std::move(prior_tree_state), + std::move(leafs_)) + ->unwrap())); + } + + private: + std::optional<::brave_wallet::OrchardTreeState> prior_tree_state_; + ShardTreeLeafs leafs_; +}; + +OrchardDecodedBlocksBundleImpl::OrchardDecodedBlocksBundleImpl( + rust::Box v) + : batch_decode_result_(std::move(v)) {} + +OrchardDecodedBlocksBundleImpl::~OrchardDecodedBlocksBundleImpl() {} + +std::optional> +OrchardDecodedBlocksBundleImpl::GetDiscoveredNotes() { + std::vector result; + + for (size_t i = 0; i < batch_decode_result_->size(); i++) { + result.emplace_back(OrchardNote({ + batch_decode_result_->note_addr(i), + batch_decode_result_->note_block_height(i), + batch_decode_result_->note_nullifier(i), + batch_decode_result_->note_value(i), + batch_decode_result_->note_commitment_tree_position(i), + batch_decode_result_->note_rho(i), + batch_decode_result_->note_rseed(i), + })); + } + + return result; +} + +BatchOrchardDecodeBundle& OrchardDecodedBlocksBundleImpl::GetDecodeBundle() { + return *batch_decode_result_; +} + +// static +std::unique_ptr +OrchardDecodedBlocksBundle::CreateTestingBuilder() { + CHECK_IS_TEST(); + return base::WrapUnique(new TestingBuilderImpl()); +} + +} // namespace brave_wallet::orchard diff --git a/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h b/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h new file mode 100644 index 000000000000..350784ba754d --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_DECODED_BLOCKS_BUNDE_IMPL_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_DECODED_BLOCKS_BUNDE_IMPL_H_ + +#include + +#include "brave/components/brave_wallet/browser/zcash/rust/lib.rs.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h" +#include "third_party/rust/cxx/v1/cxx.h" + +namespace brave_wallet::orchard { + +class OrchardDecodedBlocksBundleImpl : public OrchardDecodedBlocksBundle { + public: + explicit OrchardDecodedBlocksBundleImpl(rust::Box); + ~OrchardDecodedBlocksBundleImpl() override; + + std::optional> GetDiscoveredNotes() + override; + BatchOrchardDecodeBundle& GetDecodeBundle(); + + private: + rust::Box batch_decode_result_; +}; + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_DECODED_BLOCKS_BUNDE_IMPL_H_ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_shard_tree.h b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree.h new file mode 100644 index 000000000000..f2c1f8db7bfc --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_SHARD_TREE_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_SHARD_TREE_H_ + +#include +#include +#include + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet::orchard { + +class OrchardShardTree { + public: + virtual ~OrchardShardTree() {} + + virtual bool TruncateToCheckpoint(uint32_t checkpoint_id) = 0; + + virtual bool ApplyScanResults( + std::unique_ptr commitments) = 0; + + virtual base::expected CalculateWitness( + uint32_t note_commitment_tree_position, + uint32_t checkpoint) = 0; + + static std::unique_ptr Create( + std::unique_ptr<::brave_wallet::OrchardShardTreeDelegate> delegate); + + static std::unique_ptr CreateForTesting( + std::unique_ptr<::brave_wallet::OrchardShardTreeDelegate> delegate); +}; + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_SHARD_TREE_H_ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.cc b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.cc new file mode 100644 index 000000000000..b16179734088 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.cc @@ -0,0 +1,362 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.h" + +#include +#include +#include +#include + +#include "base/memory/ptr_util.h" +#include "base/ranges/algorithm.h" +#include "brave/components/brave_wallet/browser/zcash/rust/cxx/src/shard_store.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h" +#include "brave/components/brave_wallet/common/hex_utils.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet::orchard { + +::brave_wallet::OrchardShardAddress From(const ShardTreeAddress& addr) { + return ::brave_wallet::OrchardShardAddress{addr.level, addr.index}; +} + +ShardTreeAddress From(const ::brave_wallet::OrchardShardAddress& addr) { + return ShardTreeAddress{addr.level, addr.index}; +} + +ShardTreeCap From(::brave_wallet::OrchardCap& shard_store_cap) { + ::rust::Vec data; + data.reserve(shard_store_cap.data.size()); + base::ranges::copy(shard_store_cap.data, std::back_inserter(data)); + return ShardTreeCap{std::move(data)}; +} + +::brave_wallet::OrchardCap From(const ShardTreeCap& cap) { + ::brave_wallet::OrchardCap shard_store_cap; + shard_store_cap.data.reserve(cap.data.size()); + base::ranges::copy(cap.data, std::back_inserter(shard_store_cap.data)); + return shard_store_cap; +} + +::brave_wallet::OrchardShard From(const ShardTreeShard& tree) { + std::optional shard_root_hash; + if (!tree.hash.empty()) { + CHECK_EQ(kOrchardShardTreeHashSize, tree.hash.size()); + OrchardShardRootHash hash_value; + base::ranges::copy(tree.hash, hash_value.begin()); + shard_root_hash = hash_value; + } + + std::vector data; + data.reserve(tree.data.size()); + base::ranges::copy(tree.data, std::back_inserter(data)); + + return ::brave_wallet::OrchardShard(From(tree.address), shard_root_hash, + std::move(data)); +} + +ShardTreeShard From(const ::brave_wallet::OrchardShard& tree) { + ::rust::Vec data; + data.reserve(tree.shard_data.size()); + base::ranges::copy(tree.shard_data, std::back_inserter(data)); + + ::rust::Vec hash; + if (tree.root_hash) { + base::ranges::copy(tree.root_hash.value(), std::back_inserter(hash)); + } + return ShardTreeShard{From(tree.address), std::move(hash), std::move(data)}; +} + +ShardTreeCheckpoint From(const ::brave_wallet::OrchardCheckpoint& checkpoint) { + ::rust::Vec marks_removed; + base::ranges::copy(checkpoint.marks_removed, + std::back_inserter(marks_removed)); + return ShardTreeCheckpoint{!checkpoint.tree_state_position.has_value(), + checkpoint.tree_state_position.value_or(0), + marks_removed}; +} + +ShardTreeCheckpointBundle From( + const ::brave_wallet::OrchardCheckpointBundle& checkpoint_bundle) { + return ShardTreeCheckpointBundle(checkpoint_bundle.checkpoint_id, + From(checkpoint_bundle.checkpoint)); +} + +::brave_wallet::OrchardCheckpoint From(const ShardTreeCheckpoint& checkpoint) { + CheckpointTreeState checkpoint_tree_state = std::nullopt; + if (!checkpoint.empty) { + checkpoint_tree_state = checkpoint.position; + } + return ::brave_wallet::OrchardCheckpoint{ + checkpoint_tree_state, + std::vector(checkpoint.mark_removed.begin(), + checkpoint.mark_removed.end())}; +} + +ShardStoreStatusCode shard_store_get_shard(const ShardStoreContext& ctx, + const ShardTreeAddress& addr, + ShardTreeShard& input) { + auto shard = ctx.GetShard(From(addr)); + if (!shard.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!shard.value()) { + return ShardStoreStatusCode::None; + } + input = From(**shard); + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_last_shard(const ShardStoreContext& ctx, + ShardTreeShard& input) { + auto shard = ctx.LastShard(4); + if (!shard.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!shard.value()) { + return ShardStoreStatusCode::None; + } + input = From(**shard); + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_put_shard(ShardStoreContext& ctx, + const ShardTreeShard& tree) { + auto result = ctx.PutShard(From(tree)); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_get_shard_roots( + const ShardStoreContext& ctx, + ::rust::Vec& input) { + auto shard = ctx.GetShardRoots(4); + if (!shard.has_value()) { + return ShardStoreStatusCode::Error; + } + for (const auto& root : *shard) { + input.push_back(From(root)); + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_truncate(ShardStoreContext& ctx, + const ShardTreeAddress& address) { + auto result = ctx.Truncate(address.index); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_get_cap(const ShardStoreContext& ctx, + ShardTreeCap& input) { + auto result = ctx.GetCap(); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + input = From(**result); + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_put_cap(ShardStoreContext& ctx, + const ShardTreeCap& tree) { + auto result = ctx.PutCap(From(tree)); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_min_checkpoint_id(const ShardStoreContext& ctx, + uint32_t& input) { + auto result = ctx.MinCheckpointId(); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + input = **result; + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_max_checkpoint_id(const ShardStoreContext& ctx, + uint32_t& input) { + auto result = ctx.MaxCheckpointId(); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + input = **result; + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_add_checkpoint( + ShardStoreContext& ctx, + uint32_t checkpoint_id, + const ShardTreeCheckpoint& checkpoint) { + auto result = ctx.AddCheckpoint(checkpoint_id, From(checkpoint)); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_checkpoint_count(const ShardStoreContext& ctx, + size_t& into) { + auto result = ctx.CheckpointCount(); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } + into = *result; + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_get_checkpoint_at_depth( + const ShardStoreContext& ctx, + size_t depth, + uint32_t& into_checkpoint_id, + ShardTreeCheckpoint& into_checkpoint) { + auto checkpoint_id = ctx.GetCheckpointAtDepth(depth); + if (!checkpoint_id.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!checkpoint_id.value()) { + return ShardStoreStatusCode::None; + } + into_checkpoint_id = **checkpoint_id; + + auto checkpoint = ctx.GetCheckpoint(into_checkpoint_id); + if (!checkpoint.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!checkpoint.value()) { + return ShardStoreStatusCode::None; + } + into_checkpoint = From((**checkpoint).checkpoint); + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_get_checkpoint(const ShardStoreContext& ctx, + uint32_t checkpoint_id, + ShardTreeCheckpoint& input) { + auto checkpoint = ctx.GetCheckpoint(checkpoint_id); + if (!checkpoint.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!checkpoint.value()) { + return ShardStoreStatusCode::None; + } + input = From((**checkpoint).checkpoint); + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_update_checkpoint( + ShardStoreContext& ctx, + uint32_t checkpoint_id, + const ShardTreeCheckpoint& checkpoint) { + auto result = ctx.UpdateCheckpoint(checkpoint_id, From(checkpoint)); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_remove_checkpoint(ShardStoreContext& ctx, + uint32_t checkpoint_id) { + auto result = ctx.RemoveCheckpoint(checkpoint_id); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_truncate_checkpoint(ShardStoreContext& ctx, + uint32_t checkpoint_id) { + auto result = ctx.TruncateCheckpoints(checkpoint_id); + if (!result.has_value()) { + return ShardStoreStatusCode::Error; + } else if (!result.value()) { + return ShardStoreStatusCode::None; + } + return ShardStoreStatusCode::Ok; +} + +ShardStoreStatusCode shard_store_get_checkpoints( + const ShardStoreContext& ctx, + size_t limit, + ::rust::Vec& into) { + auto checkpoints = ctx.GetCheckpoints(limit); + if (!checkpoints.has_value()) { + return ShardStoreStatusCode::Error; + } + if (checkpoints->empty()) { + return ShardStoreStatusCode::None; + } + for (const auto& checkpoint : checkpoints.value()) { + into.push_back(From(checkpoint)); + } + return ShardStoreStatusCode::Ok; +} + +bool OrchardShardTreeImpl::ApplyScanResults( + std::unique_ptr commitments) { + auto* bundle_impl = + static_cast(commitments.get()); + return orcard_shard_tree_->insert_commitments(bundle_impl->GetDecodeBundle()); +} + +base::expected +OrchardShardTreeImpl::CalculateWitness(uint32_t note_commitment_tree_position, + uint32_t checkpoint) { + auto result = orcard_shard_tree_->calculate_witness( + note_commitment_tree_position, checkpoint); + if (!result->is_ok()) { + return base::unexpected(result->error_message().c_str()); + } + + auto value = result->unwrap(); + + OrchardNoteWitness witness; + witness.position = note_commitment_tree_position; + for (size_t i = 0; i < value->size(); i++) { + witness.merkle_path.push_back(value->item(i)); + } + + return witness; +} + +bool OrchardShardTreeImpl::TruncateToCheckpoint(uint32_t checkpoint_id) { + return orcard_shard_tree_->truncate(checkpoint_id); +} + +OrchardShardTreeImpl::OrchardShardTreeImpl( + rust::Box orcard_shard_tree) + : orcard_shard_tree_(std::move(orcard_shard_tree)) {} + +OrchardShardTreeImpl::~OrchardShardTreeImpl() {} + +// static +std::unique_ptr OrchardShardTree::Create( + std::unique_ptr<::brave_wallet::OrchardShardTreeDelegate> delegate) { + auto shard_tree_result = + ::brave_wallet::orchard::create_shard_tree(std::move(delegate)); + if (!shard_tree_result->is_ok()) { + return nullptr; + } + return base::WrapUnique( + new OrchardShardTreeImpl(shard_tree_result->unwrap())); +} + +} // namespace brave_wallet::orchard diff --git a/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.h b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.h new file mode 100644 index 000000000000..e7d29828631f --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_impl.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_SHARD_TREE_IMPL_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_SHARD_TREE_IMPL_H_ + +#include +#include + +#include "brave/components/brave_wallet/browser/zcash/rust/lib.rs.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_shard_tree.h" + +namespace brave_wallet::orchard { + +class OrchardShardTreeImpl : public OrchardShardTree { + public: + explicit OrchardShardTreeImpl( + rust::Box orcard_shard_tree); + ~OrchardShardTreeImpl() override; + + bool TruncateToCheckpoint(uint32_t checkpoint_id) override; + + bool ApplyScanResults( + std::unique_ptr commitments) override; + + base::expected CalculateWitness( + uint32_t note_commitment_tree_position, + uint32_t checkpoint) override; + + private: + ::rust::Box orcard_shard_tree_; +}; + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_SHARD_TREE_IMPL_H_ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_unittest.cc b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_unittest.cc new file mode 100644 index 000000000000..1b0620ae8dc6 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_shard_tree_unittest.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include + +#include "base/files/scoped_temp_dir.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.h" +#include "brave/components/brave_wallet/browser/zcash/zcash_test_utils.h" + +static_assert(BUILDFLAG(ENABLE_ORCHARD)); + +namespace brave_wallet { + +class OrchardShardTreeUnitTest : public testing::Test { + public: + OrchardShardTreeUnitTest() {} + + ~OrchardShardTreeUnitTest() override = default; + + void SetUp() override { + ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); + base::FilePath db_path( + temp_dir_.GetPath().Append(FILE_PATH_LITERAL("orchard.db"))); + storage_ = base::MakeRefCounted(path_to_database); + shard_tree_ = OrchardTestingShardTreeImpl::Create( + td::make_unique(account_id.Clone(), + storage_)); + } + + protected: + base::test::TaskEnvironment task_environment_; + base::ScopedTempDir temp_dir_; + + scoped_refptr storage_; + std::unique_ptr shard_tree_; +}; + +TEST_F(OrchardShardTreeUnitTest, DiscoverNewNotes) {} + +} // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/rust/orchard_test_utils.h b/components/brave_wallet/browser/zcash/rust/orchard_test_utils.h new file mode 100644 index 000000000000..e7aa5c783909 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_test_utils.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TEST_UTILS_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TEST_UTILS_H_ + +#include + +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet::orchard { + +class OrchardTestUtils { + public: + virtual ~OrchardTestUtils() = default; + + virtual OrchardCommitmentValue CreateMockCommitmentValue(uint32_t position, + uint32_t rseed) = 0; + + static std::unique_ptr Create(); +}; + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TEST_UTILS_H_ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.cc b/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.cc new file mode 100644 index 000000000000..b1ccd2acc7fd --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.h" + +#include + +#include "base/memory/ptr_util.h" +#include "brave/components/brave_wallet/browser/zcash/rust/lib.rs.h" + +namespace brave_wallet::orchard { + +OrchardTestUtilsImpl::OrchardTestUtilsImpl() {} + +OrchardTestUtilsImpl::~OrchardTestUtilsImpl() {} + +OrchardCommitmentValue OrchardTestUtilsImpl::CreateMockCommitmentValue( + uint32_t position, + uint32_t rseed) { + return create_mock_commitment(position, rseed); +} + +// static +std::unique_ptr OrchardTestUtils::Create() { + return base::WrapUnique(new OrchardTestUtilsImpl()); +} + +} // namespace brave_wallet::orchard diff --git a/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.h b/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.h new file mode 100644 index 000000000000..78d04baed151 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_test_utils_impl.h @@ -0,0 +1,24 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TEST_UTILS_IMPL_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TEST_UTILS_IMPL_H_ + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_test_utils.h" + +namespace brave_wallet::orchard { + +class OrchardTestUtilsImpl : public OrchardTestUtils { + public: + OrchardTestUtilsImpl(); + ~OrchardTestUtilsImpl() override; + + OrchardCommitmentValue CreateMockCommitmentValue(uint32_t position, + uint32_t rseed) override; +}; + +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TEST_UTILS_IMPL_H_ diff --git a/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.cc b/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.cc new file mode 100644 index 000000000000..fb438b97a5ee --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.h" + +#include + +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_decoded_blocks_bunde_impl.h" + +namespace brave_wallet::orchard { + +bool OrchardTestingShardTreeImpl::ApplyScanResults( + std::unique_ptr commitments) { + auto* bundle_impl = + static_cast(commitments.get()); + return orcard_shard_tree_->insert_commitments(bundle_impl->GetDecodeBundle()); +} + +base::expected +OrchardTestingShardTreeImpl::CalculateWitness( + uint32_t note_commitment_tree_position, + uint32_t checkpoint) { + auto result = orcard_shard_tree_->calculate_witness( + note_commitment_tree_position, checkpoint); + if (!result->is_ok()) { + return base::unexpected(result->error_message().c_str()); + } + + auto value = result->unwrap(); + + OrchardNoteWitness witness; + witness.position = note_commitment_tree_position; + for (size_t i = 0; i < value->size(); i++) { + witness.merkle_path.push_back(value->item(i)); + } + + return witness; +} + +bool OrchardTestingShardTreeImpl::TruncateToCheckpoint(uint32_t checkpoint_id) { + return orcard_shard_tree_->truncate(checkpoint_id); +} + +OrchardTestingShardTreeImpl::OrchardTestingShardTreeImpl( + rust::Box orcard_shard_tree) + : orcard_shard_tree_(std::move(orcard_shard_tree)) {} + +OrchardTestingShardTreeImpl::~OrchardTestingShardTreeImpl() {} + +// static +std::unique_ptr OrchardShardTree::CreateForTesting( + std::unique_ptr<::brave_wallet::OrchardShardTreeDelegate> delegate) { + auto shard_tree_result = + ::brave_wallet::orchard::create_testing_shard_tree(std::move(delegate)); + if (!shard_tree_result->is_ok()) { + return nullptr; + } + return base::WrapUnique( + new OrchardTestingShardTreeImpl(shard_tree_result->unwrap())); +} + +} // namespace brave_wallet::orchard diff --git a/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.h b/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.h new file mode 100644 index 000000000000..b4039cf42286 --- /dev/null +++ b/components/brave_wallet/browser/zcash/rust/orchard_testing_shard_tree_impl.h @@ -0,0 +1,38 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TESTING_SHARD_TREE_IMPL_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TESTING_SHARD_TREE_IMPL_H_ + +#include +#include +#include + +#include "brave/components/brave_wallet/browser/zcash/rust/lib.rs.h" +#include "brave/components/brave_wallet/browser/zcash/rust/orchard_shard_tree.h" + +namespace brave_wallet::orchard { + +class OrchardTestingShardTreeImpl : public OrchardShardTree { + public: + OrchardTestingShardTreeImpl( + rust::Box orcard_shard_tree); + ~OrchardTestingShardTreeImpl() override; + + bool TruncateToCheckpoint(uint32_t checkpoint_id) override; + + bool ApplyScanResults( + std::unique_ptr commitments) override; + + base::expected CalculateWitness( + uint32_t note_commitment_tree_position, + uint32_t checkpoint) override; + + private: + ::rust::Box orcard_shard_tree_; +}; +} // namespace brave_wallet::orchard + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_RUST_ORCHARD_TESTING_SHARD_TREE_IMPL_H_ diff --git a/components/brave_wallet/browser/zcash/rust/unauthorized_orchard_bundle_impl.cc b/components/brave_wallet/browser/zcash/rust/unauthorized_orchard_bundle_impl.cc index 924c4c58cec2..3c4ad485fcd8 100644 --- a/components/brave_wallet/browser/zcash/rust/unauthorized_orchard_bundle_impl.cc +++ b/components/brave_wallet/browser/zcash/rust/unauthorized_orchard_bundle_impl.cc @@ -37,6 +37,7 @@ std::unique_ptr UnauthorizedOrchardBundle::Create( CHECK_IS_TEST(); auto bundle_result = create_testing_orchard_bundle( ::rust::Slice{tree_state.data(), tree_state.size()}, + ::rust::Vec<::brave_wallet::orchard::OrchardSpend>(), std::move(outputs), random_seed_for_testing.value()); if (!bundle_result->is_ok()) { return nullptr; @@ -46,6 +47,7 @@ std::unique_ptr UnauthorizedOrchardBundle::Create( } else { auto bundle_result = create_orchard_bundle( ::rust::Slice{tree_state.data(), tree_state.size()}, + ::rust::Vec<::brave_wallet::orchard::OrchardSpend>(), std::move(outputs)); if (!bundle_result->is_ok()) { return nullptr; diff --git a/components/brave_wallet/browser/zcash/zcash_orchard_storage.cc b/components/brave_wallet/browser/zcash/zcash_orchard_storage.cc index bea195b8c843..faa8182c3a6b 100644 --- a/components/brave_wallet/browser/zcash/zcash_orchard_storage.cc +++ b/components/brave_wallet/browser/zcash/zcash_orchard_storage.cc @@ -16,7 +16,6 @@ #include "brave/components/brave_wallet/common/hex_utils.h" #include "sql/database.h" #include "sql/meta_table.h" -#include "sql/statement.h" #include "sql/transaction.h" namespace brave_wallet { @@ -25,6 +24,10 @@ namespace { #define kNotesTable "notes" #define kSpentNotesTable "spent_notes" #define kAccountMeta "account_meta" +#define kShardTree "shard_tree" +#define kShardTreeCap "shard_tree_cap" +#define kShardTreeCheckpoints "checkpoints" +#define kCheckpointsMarksRemoved "checkpoints_mark_removed" const int kEmptyDbVersionNumber = 1; const int kCurrentVersionNumber = 2; @@ -37,8 +40,51 @@ std::optional ReadUint32(sql::Statement& statement, size_t index) { return static_cast(v); } +base::expected ReadCheckpointTreeState( + sql::Statement& statement, + size_t index) { + if (statement.GetColumnType(index) == sql::ColumnType::kNull) { + return std::nullopt; + } + auto v = ReadUint32(statement, index); + if (!v) { + return base::unexpected("Format error"); + } + return *v; +} + +// std::optional> ReadBlobData(sql::Statement& statement, +// size_t index) { +// if (statement.GetColumnType(index) == sql::ColumnType::kNull) { +// return std::nullopt; +// } +// auto blob = statement.ColumnBlob(index); +// return std::vector(blob.begin(), blob.end()); +// } + +base::expected, std::string> ReadRootHash( + sql::Statement& statement, + size_t index) { + if (statement.GetColumnType(index) == sql::ColumnType::kNull) { + return std::nullopt; + } + auto v = statement.ColumnBlob(index); + if (v.size() != kOrchardShardTreeHashSize) { + return base::unexpected("Size error"); + } + std::array result; + base::ranges::copy(v.begin(), v.end(), result.begin()); + return result; +} + } // namespace +ZCashOrchardStorage::AccountMeta::AccountMeta() = default; +ZCashOrchardStorage::AccountMeta::~AccountMeta() = default; +ZCashOrchardStorage::AccountMeta::AccountMeta(const AccountMeta&) = default; +ZCashOrchardStorage::AccountMeta& ZCashOrchardStorage::AccountMeta::operator=( + const AccountMeta&) = default; + ZCashOrchardStorage::ZCashOrchardStorage(base::FilePath path_to_database) : db_file_path_(std::move(path_to_database)) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -110,8 +156,12 @@ bool ZCashOrchardStorage::CreateSchema() { "id INTEGER PRIMARY KEY AUTOINCREMENT," "account_id TEXT NOT NULL," "amount INTEGER NOT NULL," + "addr BLOB NOT NULL," "block_id INTEGER NOT NULL," - "nullifier BLOB NOT NULL UNIQUE);") && + "commitment_tree_position INTEGER," + "nullifier BLOB NOT NULL UNIQUE," + "rho BLOB NOT NULL," + "rseed BLOB NOT NULL);") && database_->Execute("CREATE TABLE " kSpentNotesTable " (" "id INTEGER PRIMARY KEY AUTOINCREMENT," @@ -122,8 +172,38 @@ bool ZCashOrchardStorage::CreateSchema() { " (" "account_id TEXT NOT NULL PRIMARY KEY," "account_birthday INTEGER NOT NULL," - "latest_scanned_block INTEGER NOT NULL," - "latest_scanned_block_hash TEXT NOT NULL);") && + "latest_scanned_block INTEGER," + "latest_scanned_block_hash TEXT);") && + database_->Execute( + "CREATE TABLE " kShardTree + " (" + "account_id TEXT NOT NULL," + "shard_index INTEGER NOT NULL," + "subtree_end_height INTEGER," + "root_hash BLOB," + "shard_data BLOB," + "CONSTRAINT shard_index_unique UNIQUE (shard_index, account_id)," + "CONSTRAINT root_unique UNIQUE (root_hash, account_id));") && + database_->Execute("CREATE TABLE " kShardTreeCheckpoints + " (" + "account_id TEXT NOT NULL," + "checkpoint_id INTEGER PRIMARY KEY," + "position INTEGER)") && + database_->Execute("CREATE TABLE " kCheckpointsMarksRemoved + " (" + "account_id TEXT NOT NULL," + "checkpoint_id INTEGER NOT NULL," + "mark_removed_position INTEGER NOT NULL," + "FOREIGN KEY (checkpoint_id) REFERENCES " + "orchard_tree_checkpoints(checkpoint_id)" + "ON DELETE CASCADE," + "CONSTRAINT spend_position_unique UNIQUE " + "(checkpoint_id, mark_removed_position, account_id)" + ")") && + database_->Execute("CREATE TABLE " kShardTreeCap + " (" + "account_id TEXT NOT NULL," + "cap_data BLOB NOT NULL)") && transaction.Commit(); } @@ -133,33 +213,28 @@ bool ZCashOrchardStorage::UpdateSchema() { } base::expected -ZCashOrchardStorage::RegisterAccount( - mojom::AccountIdPtr account_id, - uint32_t account_birthday_block, - const std::string& account_birthday_block_hash) { +ZCashOrchardStorage::RegisterAccount(mojom::AccountIdPtr account_id, + uint32_t account_birthday_block) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!EnsureDbInit()) { return base::unexpected( - Error{ErrorCode::kDbInitError, "Failed to init database "}); + Error{ErrorCode::kDbInitError, "Failed to init database"}); } sql::Transaction transaction(database_.get()); if (!transaction.Begin()) { return base::unexpected( - Error{ErrorCode::kDbInitError, "Failed to init database "}); + Error{ErrorCode::kDbInitError, "Failed to start transaction"}); } sql::Statement register_account_statement(database_->GetCachedStatement( SQL_FROM_HERE, "INSERT INTO " kAccountMeta " " - "(account_id, account_birthday, latest_scanned_block, " - "latest_scanned_block_hash) " - "VALUES (?, ?, ?, ?)")); + "(account_id, account_birthday) " + "VALUES (?, ?)")); register_account_statement.BindString(0, account_id->unique_key); register_account_statement.BindInt64(1, account_birthday_block); - register_account_statement.BindInt64(2, account_birthday_block); - register_account_statement.BindString(3, account_birthday_block_hash); if (!register_account_statement.Run()) { return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, @@ -167,12 +242,13 @@ ZCashOrchardStorage::RegisterAccount( } if (!transaction.Commit()) { - return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + return base::unexpected(Error{ErrorCode::kFailedToCommitTransaction, database_->GetErrorMessage()}); } - return AccountMeta{account_birthday_block, account_birthday_block, - account_birthday_block_hash}; + AccountMeta meta; + meta.account_birthday = account_birthday_block; + return meta; } base::expected @@ -198,16 +274,26 @@ ZCashOrchardStorage::GetAccountMeta(mojom::AccountIdPtr account_id) { AccountMeta account_meta; auto account_birthday = ReadUint32(resolve_account_statement, 0); - auto latest_scanned_block = ReadUint32(resolve_account_statement, 1); - if (!account_birthday || !latest_scanned_block) { + if (!account_birthday) { return base::unexpected( Error{ErrorCode::kInternalError, "Database format error"}); } + account_meta.account_birthday = account_birthday.value(); + + if (resolve_account_statement.GetColumnType(1) != sql::ColumnType::kNull) { + auto latest_scanned_block = ReadUint32(resolve_account_statement, 1); + if (!latest_scanned_block) { + return base::unexpected( + Error{ErrorCode::kInternalError, "Database format error"}); + } + account_meta.latest_scanned_block_id = latest_scanned_block.value(); + } + + if (resolve_account_statement.GetColumnType(2) != sql::ColumnType::kNull) { + account_meta.latest_scanned_block_hash = + resolve_account_statement.ColumnString(2); + } - account_meta.account_birthday = *account_birthday; - account_meta.latest_scanned_block_id = *latest_scanned_block; - account_meta.latest_scanned_block_hash = - resolve_account_statement.ColumnString(2); return account_meta; } @@ -267,7 +353,69 @@ std::optional ZCashOrchardStorage::HandleChainReorg( return std::nullopt; } -base::expected, ZCashOrchardStorage::Error> +base::expected +ZCashOrchardStorage::ResetAccountSyncState(mojom::AccountIdPtr account_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + // Clear cap + sql::Statement clear_cap_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kShardTreeCap " WHERE account_id = ?;")); + clear_cap_statement.BindString(0, account_id->unique_key); + + // Clear shards + sql::Statement clear_shards_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kShardTree " WHERE account_id = ?;")); + clear_shards_statement.BindString(0, account_id->unique_key); + + // Clear discovered notes + sql::Statement clear_discovered_notes(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kNotesTable " WHERE account_id = ?;")); + clear_discovered_notes.BindString(0, account_id->unique_key); + + // Clear spent notes + sql::Statement clear_spent_notes(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kSpentNotesTable " WHERE account_id = ?;")); + clear_spent_notes.BindString(0, account_id->unique_key); + + // Clear checkpoints + sql::Statement clear_checkpoints_statement(database_->GetCachedStatement( + SQL_FROM_HERE, + "DELETE FROM " kShardTreeCheckpoints " WHERE account_id = ?;")); + clear_checkpoints_statement.BindString(0, account_id->unique_key); + + // Update account meta + sql::Statement update_account_meta(database_->GetCachedStatement( + SQL_FROM_HERE, "UPDATE " kAccountMeta " " + "SET latest_scanned_block = NULL, " + "latest_scanned_block_hash = NULL WHERE account_id = ?;")); + update_account_meta.BindString(0, account_id->unique_key); + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + if (!clear_cap_statement.Run() || !clear_shards_statement.Run() || + !clear_discovered_notes.Run() || !clear_spent_notes.Run() || + !clear_checkpoints_statement.Run() || !update_account_meta.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + if (!transaction.Commit()) { + return base::unexpected(Error{ErrorCode::kFailedToCommitTransaction, + database_->GetErrorMessage()}); + } + + return true; +} + +base::expected, ZCashOrchardStorage::Error> ZCashOrchardStorage::GetNullifiers(mojom::AccountIdPtr account_id) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -283,18 +431,18 @@ ZCashOrchardStorage::GetNullifiers(mojom::AccountIdPtr account_id) { resolve_note_spents.BindString(0, account_id->unique_key); - std::vector result; + std::vector result; while (resolve_note_spents.Step()) { - OrchardNullifier nf; + OrchardNoteSpend spend; auto block_id = ReadUint32(resolve_note_spents, 0); if (!block_id) { return base::unexpected( Error{ErrorCode::kDbInitError, "Wrong database format"}); } - nf.block_id = block_id.value(); + spend.block_id = block_id.value(); auto nullifier = resolve_note_spents.ColumnBlob(1); - base::ranges::copy(nullifier, nf.nullifier.begin()); - result.push_back(std::move(nf)); + base::ranges::copy(nullifier, spend.nullifier.begin()); + result.push_back(std::move(spend)); } if (!resolve_note_spents.Succeeded()) { return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, @@ -315,8 +463,9 @@ ZCashOrchardStorage::GetSpendableNotes(mojom::AccountIdPtr account_id) { sql::Statement resolve_unspent_notes(database_->GetCachedStatement( SQL_FROM_HERE, "SELECT " - "notes.block_id, notes.amount," - "notes.nullifier FROM " kNotesTable + "notes.block_id, notes.commitment_tree_position, notes.amount," + "notes.rho, notes.rseed," + "notes.nullifier, notes.addr FROM " kNotesTable " " "LEFT OUTER JOIN spent_notes " "ON notes.nullifier = spent_notes.nullifier AND notes.account_id = " @@ -329,15 +478,31 @@ ZCashOrchardStorage::GetSpendableNotes(mojom::AccountIdPtr account_id) { while (resolve_unspent_notes.Step()) { OrchardNote note; auto block_id = ReadUint32(resolve_unspent_notes, 0); - auto amount = ReadUint32(resolve_unspent_notes, 1); - if (!block_id || !amount) { + auto commitment_tree_position = ReadUint32(resolve_unspent_notes, 1); + auto amount = ReadUint32(resolve_unspent_notes, 2); + if (!block_id || !amount || !commitment_tree_position) { return base::unexpected( Error{ErrorCode::kDbInitError, "Wrong database format"}); } + auto rho = ReadSizedBlob(resolve_unspent_notes, 3); + auto rseed = ReadSizedBlob(resolve_unspent_notes, 4); + auto nf = ReadSizedBlob(resolve_unspent_notes, 5); + auto addr = ReadSizedBlob(resolve_unspent_notes, 6); + + if (!rho.has_value() || !rho.value() || !rseed.has_value() || + !rseed.value() || !nf.has_value() || !nf.value() || !addr.has_value() || + !addr.value()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, "Wrong database format"}); + } + note.block_id = block_id.value(); note.amount = amount.value(); - auto nullifier = resolve_unspent_notes.ColumnBlob(2); - base::ranges::copy(nullifier, note.nullifier.begin()); + note.orchard_commitment_tree_position = commitment_tree_position.value(); + note.rho = **rho; + note.seed = **rseed; + note.nullifier = **nf; + note.addr = **addr; result.push_back(std::move(note)); } return result; @@ -346,7 +511,7 @@ ZCashOrchardStorage::GetSpendableNotes(mojom::AccountIdPtr account_id) { std::optional ZCashOrchardStorage::UpdateNotes( mojom::AccountIdPtr account_id, const std::vector& found_notes, - const std::vector& spent_notes, + const std::vector& found_nullifiers, const uint32_t latest_scanned_block, const std::string& latest_scanned_block_hash) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -364,15 +529,22 @@ std::optional ZCashOrchardStorage::UpdateNotes( // Insert found notes to the notes table sql::Statement statement_populate_notes(database_->GetCachedStatement( SQL_FROM_HERE, "INSERT INTO " kNotesTable " " - "(account_id, amount, block_id, nullifier) " - "VALUES (?, ?, ?, ?);")); + "(account_id, amount, block_id, commitment_tree_position, " + "nullifier, rho, rseed, addr) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?);")); for (const auto& note : found_notes) { statement_populate_notes.Reset(true); statement_populate_notes.BindString(0, account_id->unique_key); statement_populate_notes.BindInt64(1, note.amount); statement_populate_notes.BindInt64(2, note.block_id); - statement_populate_notes.BindBlob(3, note.nullifier); + statement_populate_notes.BindInt64(3, + note.orchard_commitment_tree_position); + statement_populate_notes.BindBlob(4, note.nullifier); + statement_populate_notes.BindBlob(5, note.rho); + statement_populate_notes.BindBlob(6, note.seed); + statement_populate_notes.BindBlob(7, note.addr); + if (!statement_populate_notes.Run()) { return Error{ErrorCode::kFailedToExecuteStatement, database_->GetErrorMessage()}; @@ -385,7 +557,7 @@ std::optional ZCashOrchardStorage::UpdateNotes( "(account_id, spent_block_id, nullifier) " "VALUES (?, ?, ?);")); - for (const auto& spent : spent_notes) { + for (const auto& spent : found_nullifiers) { statement_populate_spent_notes.Reset(true); statement_populate_spent_notes.BindString(0, account_id->unique_key); statement_populate_spent_notes.BindInt64(1, spent.block_id); @@ -420,4 +592,865 @@ std::optional ZCashOrchardStorage::UpdateNotes( return std::nullopt; } +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetLatestShardIndex(mojom::AccountIdPtr account_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement resolve_max_shard_id( + database_->GetCachedStatement(SQL_FROM_HERE, + "SELECT " + "MAX(shard_index) FROM " kShardTree " " + "WHERE account_id = ?;")); + + resolve_max_shard_id.BindString(0, account_id->unique_key); + if (resolve_max_shard_id.Step()) { + if (resolve_max_shard_id.GetColumnType(0) == sql::ColumnType::kNull) { + return std::nullopt; + } + auto shard_index = ReadUint32(resolve_max_shard_id, 0); + if (!shard_index) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + return shard_index.value(); + } + + return std::nullopt; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetCap(mojom::AccountIdPtr account_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement resolve_cap( + database_->GetCachedStatement(SQL_FROM_HERE, + "SELECT " + "cap_data FROM " kShardTreeCap " " + "WHERE account_id = ?;")); + resolve_cap.BindString(0, account_id->unique_key); + + if (!resolve_cap.Step()) { + if (!resolve_cap.Succeeded()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + return std::nullopt; + } + + OrchardCap result; + auto blob = resolve_cap.ColumnBlob(0); + result.data.reserve(blob.size()); + base::ranges::copy(blob, std::back_inserter(result.data)); + + return result; +} + +base::expected ZCashOrchardStorage::PutCap( + mojom::AccountIdPtr account_id, + OrchardCap cap) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + auto existing_cap = GetCap(account_id.Clone()); + if (!existing_cap.has_value()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, existing_cap.error().message}); + } + + sql::Statement stmnt; + if (!existing_cap.value()) { + stmnt.Assign(database_->GetCachedStatement(SQL_FROM_HERE, + "INSERT INTO " kShardTreeCap " " + "(account_id, cap_data) " + "VALUES (?, ?);")); + stmnt.BindString(0, account_id->unique_key); + stmnt.BindBlob(1, cap.data); + } else { + stmnt.Assign(database_->GetCachedStatement( + SQL_FROM_HERE, "UPDATE " kShardTreeCap " " + "SET " + "cap_data = ? WHERE account_id = ?;")); + stmnt.BindBlob(0, cap.data); + stmnt.BindString(1, account_id->unique_key); + } + + if (!stmnt.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + if (!transaction.Commit()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + return true; +} + +base::expected +ZCashOrchardStorage::UpdateSubtreeRoots( + mojom::AccountIdPtr account_id, + uint32_t start_index, + std::vector roots) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement statement_populate_roots(database_->GetCachedStatement( + SQL_FROM_HERE, + "INSERT INTO " kShardTree + " " + "(shard_index, subtree_end_height, root_hash, shard_data, account_id) " + "VALUES (?, ?, ?, ?, ?);" + + )); + + sql::Statement statement_update_roots(database_->GetCachedStatement( + SQL_FROM_HERE, + "UPDATE " kShardTree + " " + "SET subtree_end_height = :subtree_end_height, root_hash = :root_hash " + "WHERE " + "shard_index = :shard_index and account_id = :account_id;")); + + for (size_t i = 0; i < roots.size(); i++) { + if (!roots[i] || + roots[i]->complete_block_hash.size() != kOrchardCompleteBlockHashSize) { + return base::unexpected(Error{ErrorCode::kInternalError, "Wrong data"}); + } + + statement_populate_roots.Reset(true); + statement_populate_roots.BindInt64(0, start_index + i); + statement_populate_roots.BindInt64(1, roots[i]->complete_block_height); + statement_populate_roots.BindBlob(2, roots[i]->complete_block_hash); + statement_populate_roots.BindNull( + 3); // TODO(cypt4): Serialize hash as a leaf + statement_populate_roots.BindString(4, account_id->unique_key); + if (!statement_populate_roots.Run()) { + if (database_->GetErrorCode() == 19 /*SQLITE_CONSTRAINT*/) { + statement_update_roots.Reset(true); + statement_update_roots.BindInt64(0, roots[i]->complete_block_height); + statement_update_roots.BindBlob(1, roots[i]->complete_block_hash); + statement_update_roots.BindInt64(2, start_index + i); + statement_update_roots.BindString(3, account_id->unique_key); + if (!statement_update_roots.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + } else { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + } + } + + if (!transaction.Commit()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + return true; +} + +base::expected +ZCashOrchardStorage::TruncateShards(mojom::AccountIdPtr account_id, + uint32_t shard_index) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement remove_checkpoint_by_id(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kShardTree " " + "WHERE shard_index >= ? AND account_id = ?;")); + + remove_checkpoint_by_id.BindInt64(0, shard_index); + remove_checkpoint_by_id.BindString(1, account_id->unique_key); + + if (!remove_checkpoint_by_id.Run()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + if (!transaction.Commit()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + return true; +} + +base::expected ZCashOrchardStorage::PutShard( + mojom::AccountIdPtr account_id, + OrchardShard shard) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + auto existing_shard = GetShard(account_id.Clone(), shard.address); + if (!existing_shard.has_value()) { + return base::unexpected(existing_shard.error()); + } + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + if (existing_shard.value()) { + sql::Statement statement_update_shard(database_->GetCachedStatement( + SQL_FROM_HERE, + "UPDATE " kShardTree + " " + "SET root_hash = :root_hash, shard_data = :shard_data " + "WHERE shard_index = :shard_index AND account_id = :account_id;")); + + if (!shard.root_hash) { + statement_update_shard.BindNull(0); + } else { + statement_update_shard.BindBlob(0, shard.root_hash.value()); + } + statement_update_shard.BindBlob(1, shard.shard_data); + statement_update_shard.BindInt64(2, shard.address.index); + statement_update_shard.BindString(3, account_id->unique_key); + + if (!statement_update_shard.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + } else { + sql::Statement statement_put_shard(database_->GetCachedStatement( + SQL_FROM_HERE, + "INSERT INTO " kShardTree + " " + "(shard_index, root_hash, shard_data, account_id) " + "VALUES (:shard_index, :root_hash, :shard_data, :account_id);")); + + statement_put_shard.BindInt64(0, shard.address.index); + if (!shard.root_hash) { + statement_put_shard.BindNull(1); + } else { + statement_put_shard.BindBlob(1, shard.root_hash.value()); + } + statement_put_shard.BindBlob(2, shard.shard_data); + statement_put_shard.BindString(3, account_id->unique_key); + + if (!statement_put_shard.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + } + + if (!transaction.Commit()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + return true; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetShard(mojom::AccountIdPtr account_id, + OrchardShardAddress address) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement resolve_shard_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "SELECT root_hash, shard_data FROM " kShardTree " " + "WHERE account_id = ? AND shard_index = ?;")); + + resolve_shard_statement.BindString(0, account_id->unique_key); + resolve_shard_statement.BindInt64(1, address.index); + + if (!resolve_shard_statement.Step()) { + if (!resolve_shard_statement.Succeeded()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + return std::nullopt; + } + + auto hash = ReadRootHash(resolve_shard_statement, 0); + if (!hash.has_value()) { + return base::unexpected(Error{ErrorCode::kDbInitError, hash.error()}); + } + + auto shard_data = resolve_shard_statement.ColumnBlob(1); + auto shard = OrchardShard(address, hash.value(), std::vector()); + + base::ranges::copy(shard_data, std::back_inserter(shard.shard_data)); + + return shard; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::LastShard(mojom::AccountIdPtr account_id, + uint8_t shard_height) { + auto shard_index = GetLatestShardIndex(account_id.Clone()); + if (!shard_index.has_value()) { + return base::unexpected(shard_index.error()); + } + + if (!shard_index.value()) { + return std::nullopt; + } + + return GetShard( + account_id.Clone(), + OrchardShardAddress{shard_height, shard_index.value().value()}); +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetShardRoots(mojom::AccountIdPtr account_id, + uint8_t shard_level) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + std::vector result; + + sql::Statement resolve_shards_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "SELECT shard_index FROM " kShardTree + " WHERE account_id = ? ORDER BY shard_index;")); + + resolve_shards_statement.BindString(0, account_id->unique_key); + + while (resolve_shards_statement.Step()) { + auto shard_index = ReadUint32(resolve_shards_statement, 0); + if (!shard_index) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, ""}); + } + result.push_back(OrchardShardAddress{shard_level, shard_index.value()}); + } + + if (!resolve_shards_statement.is_valid()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, ""}); + } + + return result; +} + +base::expected +ZCashOrchardStorage::AddCheckpoint(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id, + OrchardCheckpoint checkpoint) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement extant_tree_state_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "SELECT position FROM " kShardTreeCheckpoints " " + "WHERE checkpoint_id = ? " + "AND account_id = ?;")); + extant_tree_state_statement.BindInt64(0, checkpoint_id); + extant_tree_state_statement.BindString(1, account_id->unique_key); + + std::optional extant_tree_state_position; + if (extant_tree_state_statement.Step()) { + auto state = ReadCheckpointTreeState(extant_tree_state_statement, 0); + if (!state.has_value()) { + return base::unexpected(Error{ErrorCode::kDbInitError, state.error()}); + } + extant_tree_state_position = state.value(); + } + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, "Failed to init database "}); + } + + // Checkpoint with same id didn't exist. + if (!extant_tree_state_position) { + sql::Statement insert_checkpoint_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "INSERT INTO " kShardTreeCheckpoints " " + "(account_id, checkpoint_id, position)" + "VALUES (?, ?, ?);")); + insert_checkpoint_statement.BindString(0, account_id->unique_key); + insert_checkpoint_statement.BindInt64(1, checkpoint_id); + if (checkpoint.tree_state_position) { + insert_checkpoint_statement.BindInt64( + 2, checkpoint.tree_state_position.value()); + } else { + insert_checkpoint_statement.BindNull(2); + } + if (!insert_checkpoint_statement.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + sql::Statement insert_marks_removed_statement(database_->GetCachedStatement( + SQL_FROM_HERE, "INSERT INTO " kCheckpointsMarksRemoved " " + "(account_id, checkpoint_id, mark_removed_position) " + "VALUES (?, ?, ?);")); + for (const auto& mark : checkpoint.marks_removed) { + insert_marks_removed_statement.Reset(true); + insert_marks_removed_statement.BindString(0, account_id->unique_key); + insert_marks_removed_statement.BindInt64(1, checkpoint_id); + insert_marks_removed_statement.BindInt64(2, mark); + + if (!insert_marks_removed_statement.Run()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + } + } else { + // Existing checkpoint should be the same + if (extant_tree_state_position.value() != checkpoint.tree_state_position) { + return base::unexpected( + Error{ErrorCode::kConsistencyError, "Consistency error"}); + } + auto marks_removed_result = + GetMarksRemoved(account_id.Clone(), checkpoint_id); + if (!marks_removed_result.has_value()) { + return base::unexpected(marks_removed_result.error()); + } + + if (!marks_removed_result.value()) { + return base::unexpected( + Error{ErrorCode::kConsistencyError, "Consistency error"}); + } + + if (marks_removed_result.value().value() != checkpoint.marks_removed) { + return base::unexpected( + Error{ErrorCode::kConsistencyError, "Consistency error"}); + } + } + + if (!transaction.Commit()) { + return base::unexpected(Error{ErrorCode::kDbInitError, ""}); + } + + return true; +} + +base::expected +ZCashOrchardStorage::CheckpointCount(mojom::AccountIdPtr account_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement resolve_checkpoints_count(database_->GetCachedStatement( + SQL_FROM_HERE, + "SELECT COUNT(*) FROM " kShardTreeCheckpoints " WHERE account_id = ?;")); + resolve_checkpoints_count.BindString(0, account_id->unique_key); + if (!resolve_checkpoints_count.Step()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + auto value = ReadUint32(resolve_checkpoints_count, 0); + + if (!value) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + return *value; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::MinCheckpointId(mojom::AccountIdPtr account_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement resolve_min_checkpoint_id(database_->GetCachedStatement( + SQL_FROM_HERE, "SELECT MIN(checkpoint_id) FROM " kShardTreeCheckpoints + " WHERE account_id = ?;")); + + resolve_min_checkpoint_id.BindString(0, account_id->unique_key); + + if (!resolve_min_checkpoint_id.Step()) { + if (!resolve_min_checkpoint_id.Succeeded()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } else { + return std::nullopt; + } + } + + if (resolve_min_checkpoint_id.GetColumnType(0) == sql::ColumnType::kNull) { + return std::nullopt; + } else { + return ReadUint32(resolve_min_checkpoint_id, 0); + } +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::MaxCheckpointId(mojom::AccountIdPtr account_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement resolve_max_checkpoint_id(database_->GetCachedStatement( + SQL_FROM_HERE, "SELECT MAX(checkpoint_id) FROM " kShardTreeCheckpoints + " WHERE account_id = ?;")); + resolve_max_checkpoint_id.BindString(0, account_id->unique_key); + + if (!resolve_max_checkpoint_id.Step()) { + if (!resolve_max_checkpoint_id.Succeeded()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } else { + return std::nullopt; + } + } + + if (resolve_max_checkpoint_id.GetColumnType(0) == sql::ColumnType::kNull) { + return std::nullopt; + } else { + return ReadUint32(resolve_max_checkpoint_id, 0); + } +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetCheckpointAtDepth(mojom::AccountIdPtr account_id, + uint32_t depth) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement get_checkpoint_at_depth_statement( + database_->GetCachedStatement(SQL_FROM_HERE, + "SELECT checkpoint_id, position " + "FROM " kShardTreeCheckpoints " " + "WHERE account_id = ? " + "ORDER BY checkpoint_id DESC " + "LIMIT 1 " + "OFFSET ?;")); + + get_checkpoint_at_depth_statement.BindString(0, account_id->unique_key); + get_checkpoint_at_depth_statement.BindInt64(1, depth); + + if (!get_checkpoint_at_depth_statement.Step()) { + if (!get_checkpoint_at_depth_statement.Succeeded()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + return std::nullopt; + } + + auto value = ReadUint32(get_checkpoint_at_depth_statement, 0); + + if (!value) { + return std::nullopt; + } + + return *value; +} + +base::expected>, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetMarksRemoved(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement get_marks_removed_statement( + database_->GetCachedStatement(SQL_FROM_HERE, + "SELECT mark_removed_position " + "FROM " kCheckpointsMarksRemoved " " + "WHERE checkpoint_id = ? AND " + "account_id = ?;")); + get_marks_removed_statement.BindInt64(0, checkpoint_id); + get_marks_removed_statement.BindString(1, account_id->unique_key); + + std::vector result; + while (get_marks_removed_statement.Step()) { + auto position = ReadUint32(get_marks_removed_statement, 0); + if (!position) { + return base::unexpected(Error{ErrorCode::kDbInitError, "Format error"}); + } + result.push_back(*position); + } + + return result; +} + +base::expected, + ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetCheckpoint(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement get_checkpoint_statement(database_->GetCachedStatement( + SQL_FROM_HERE, + "SELECT position " + "FROM " kShardTreeCheckpoints + " " + "WHERE checkpoint_id = ? AND account_id = ?;")); + + get_checkpoint_statement.BindInt64(0, checkpoint_id); + get_checkpoint_statement.BindString(1, account_id->unique_key); + if (!get_checkpoint_statement.Step()) { + return std::nullopt; + } + auto checkpoint_position = + ReadCheckpointTreeState(get_checkpoint_statement, 0); + if (!checkpoint_position.has_value()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + + sql::Statement marks_removed_statement(database_->GetCachedStatement( + SQL_FROM_HERE, + "SELECT mark_removed_position " + "FROM " kCheckpointsMarksRemoved + " " + "WHERE checkpoint_id = ? AND account_id = ?;")); + + marks_removed_statement.BindInt64(0, checkpoint_id); + marks_removed_statement.BindString(1, account_id->unique_key); + + std::vector positions; + while (marks_removed_statement.Step()) { + auto position = ReadUint32(marks_removed_statement, 0); + if (position) { + positions.push_back(*position); + } else { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, + database_->GetErrorMessage()}); + } + } + + return OrchardCheckpointBundle{ + checkpoint_id, + OrchardCheckpoint{*checkpoint_position, std::move(positions)}}; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetCheckpoints(mojom::AccountIdPtr account_id, + size_t limit) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement get_checkpoints_statement( + database_->GetCachedStatement(SQL_FROM_HERE, + "SELECT checkpoint_id, position " + "FROM " kShardTreeCheckpoints " " + "WHERE account_id = ? " + "ORDER BY position " + "LIMIT ?")); + + get_checkpoints_statement.BindString(0, account_id->unique_key); + get_checkpoints_statement.BindInt64(1, limit); + + std::vector checkpoints; + while (get_checkpoints_statement.Step()) { + auto checkpoint_id = ReadUint32(get_checkpoints_statement, 0); + auto checkpoint_position = + ReadCheckpointTreeState(get_checkpoints_statement, 1); + if (!checkpoint_id) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, ""}); + } + if (!checkpoint_position.has_value()) { + return base::unexpected(Error{ErrorCode::kFailedToExecuteStatement, ""}); + } + auto found_marks_removed = + GetMarksRemoved(account_id.Clone(), *checkpoint_id); + if (!found_marks_removed.has_value()) { + return base::unexpected(found_marks_removed.error()); + } + std::vector marks_removed; + if (found_marks_removed.value()) { + marks_removed = **found_marks_removed; + } + + checkpoints.push_back(OrchardCheckpointBundle{ + *checkpoint_id, OrchardCheckpoint(checkpoint_position.value(), + std::move(marks_removed))}); + } + return checkpoints; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardStorage::GetMaxCheckpointedHeight(mojom::AccountIdPtr account_id, + uint32_t chain_tip_height, + uint32_t min_confirmations) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + uint32_t max_checkpointed_height = chain_tip_height - min_confirmations - 1; + + sql::Statement get_max_checkpointed_height(database_->GetCachedStatement( + SQL_FROM_HERE, "SELECT checkpoint_id FROM " kShardTreeCheckpoints " " + "WHERE checkpoint_id <= ? AND " + "account_id = ? " + "ORDER BY checkpoint_id DESC " + "LIMIT 1")); + + get_max_checkpointed_height.BindInt64(0, max_checkpointed_height); + get_max_checkpointed_height.BindString(1, account_id->unique_key); + + if (!get_max_checkpointed_height.Step()) { + if (!get_max_checkpointed_height.Succeeded()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } else { + return std::nullopt; + } + } + + auto value = ReadUint32(get_max_checkpointed_height, 0); + + if (!value) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + return *value; +} + +base::expected +ZCashOrchardStorage::RemoveCheckpoint(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Statement remove_checkpoint_by_id(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kShardTreeCheckpoints " " + "WHERE checkpoint_id = ? AND account_id= ?;")); + + remove_checkpoint_by_id.BindInt64(0, checkpoint_id); + remove_checkpoint_by_id.BindString(1, account_id->unique_key); + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, "Failed to init database "}); + } + + if (!remove_checkpoint_by_id.Run()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + if (!transaction.Commit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, "Failed to init database "}); + } + + return true; +} + +base::expected +ZCashOrchardStorage::TruncateCheckpoints(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!EnsureDbInit()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, database_->GetErrorMessage()}); + } + + sql::Transaction transaction(database_.get()); + if (!transaction.Begin()) { + return base::unexpected( + Error{ErrorCode::kDbInitError, "Failed to init database "}); + } + + sql::Statement truncate_checkpoints(database_->GetCachedStatement( + SQL_FROM_HERE, "DELETE FROM " kShardTreeCheckpoints + " WHERE checkpoint_id >= ? and account_id = ?;")); + + truncate_checkpoints.BindInt64(0, checkpoint_id); + truncate_checkpoints.BindString(1, account_id->unique_key); + + if (!truncate_checkpoints.Run()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + if (!transaction.Commit()) { + return base::unexpected( + Error{ErrorCode::kNoCheckpoints, database_->GetErrorMessage()}); + } + + return true; +} + } // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/zcash_orchard_storage.h b/components/brave_wallet/browser/zcash/zcash_orchard_storage.h index 6594af6ca9b7..8789dde494eb 100644 --- a/components/brave_wallet/browser/zcash/zcash_orchard_storage.h +++ b/components/brave_wallet/browser/zcash/zcash_orchard_storage.h @@ -17,6 +17,8 @@ #include "base/types/expected.h" #include "brave/components/brave_wallet/common/brave_wallet.mojom.h" #include "brave/components/brave_wallet/common/zcash_utils.h" +#include "brave/components/services/brave_wallet/public/mojom/zcash_decoder.mojom.h" +#include "sql/statement.h" namespace sql { class Database; @@ -24,21 +26,58 @@ class Database; namespace brave_wallet { +template +base::expected>, std::string> +ReadSizedBlob(sql::Statement& statement, size_t position) { + auto columns = statement.ColumnCount(); + CHECK(columns >= 0); + if (position >= static_cast(columns)) { + return base::unexpected("Position mismatch"); + } + + if (statement.GetColumnType(position) == sql::ColumnType::kNull) { + return std::nullopt; + } + + if (statement.GetColumnType(position) != sql::ColumnType::kBlob) { + return base::unexpected("Type mismatch"); + } + + auto blob = statement.ColumnBlob(position); + if (blob.size() != T) { + return base::unexpected("Size mismatch"); + } + + std::array to; + base::ranges::copy_n(blob.begin(), to.size(), to.begin()); + return to; +} + // Implements SQLite database to store found incoming notes, // nullifiers, wallet zcash accounts and commitment trees. -class ZCashOrchardStorage { +class ZCashOrchardStorage : public base::RefCounted { public: + using WithCheckpointsCallback = + base::RepeatingCallback; + struct AccountMeta { + AccountMeta(); + ~AccountMeta(); + AccountMeta(const AccountMeta&); + AccountMeta& operator=(const AccountMeta&); uint32_t account_birthday = 0; - uint32_t latest_scanned_block_id = 0; - std::string latest_scanned_block_hash; + std::optional latest_scanned_block_id; + std::optional latest_scanned_block_hash; }; enum class ErrorCode { kDbInitError, kAccountNotFound, kFailedToExecuteStatement, - kInternalError + kFailedToCommitTransaction, + kInternalError, + kNoCheckpoints, + kConsistencyError }; struct Error { @@ -47,14 +86,14 @@ class ZCashOrchardStorage { }; explicit ZCashOrchardStorage(base::FilePath path_to_database); - ~ZCashOrchardStorage(); base::expected RegisterAccount( mojom::AccountIdPtr account_id, - uint32_t account_birthday_block, - const std::string& account_bithday_block_hash); + uint32_t account_birthday_block); base::expected GetAccountMeta( mojom::AccountIdPtr account_id); + base::expected ResetAccountSyncState( + mojom::AccountIdPtr account_id); // Removes database records which are under effect of chain reorg // Removes spendable notes and nullifiers with block_height > reorg_block @@ -67,19 +106,80 @@ class ZCashOrchardStorage { base::expected, ZCashOrchardStorage::Error> GetSpendableNotes(mojom::AccountIdPtr account_id); // Returns a list of discovered nullifiers - base::expected, Error> GetNullifiers( + base::expected, Error> GetNullifiers( mojom::AccountIdPtr account_id); // Updates database with discovered spendable notes and nullifiers // Also updates account info with latest scanned block info std::optional UpdateNotes( mojom::AccountIdPtr account_id, const std::vector& notes_to_add, - const std::vector& notes_to_delete, + const std::vector& found_nullifiers, const uint32_t latest_scanned_block, const std::string& latest_scanned_block_hash); void ResetDatabase(); + // Shard tree + base::expected, Error> GetCap( + mojom::AccountIdPtr account_id); + base::expected PutCap(mojom::AccountIdPtr account_id, + OrchardCap cap); + + base::expected TruncateShards(mojom::AccountIdPtr account_id, + uint32_t shard_index); + base::expected, Error> GetLatestShardIndex( + mojom::AccountIdPtr account_id); + base::expected PutShard(mojom::AccountIdPtr account_id, + OrchardShard shard); + base::expected, Error> GetShard( + mojom::AccountIdPtr account_id, + OrchardShardAddress address); + base::expected, Error> LastShard( + mojom::AccountIdPtr account_id, + uint8_t shard_height); + + base::expected CheckpointCount(mojom::AccountIdPtr account_id); + base::expected, Error> MinCheckpointId( + mojom::AccountIdPtr account_id); + base::expected, Error> MaxCheckpointId( + mojom::AccountIdPtr account_id); + base::expected, Error> GetCheckpointAtDepth( + mojom::AccountIdPtr account_id, + uint32_t depth); + base::expected, Error> GetMaxCheckpointedHeight( + mojom::AccountIdPtr account_id, + uint32_t chain_tip_height, + uint32_t min_confirmations); + base::expected RemoveCheckpoint(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id); + base::expected TruncateCheckpoints( + mojom::AccountIdPtr account_id, + uint32_t checkpoint_id); + base::expected AddCheckpoint(mojom::AccountIdPtr account_id, + uint32_t checkpoint_id, + OrchardCheckpoint checkpoint); + base::expected, Error> GetCheckpoints( + mojom::AccountIdPtr account_id, + size_t limit); + base::expected, Error> GetCheckpoint( + mojom::AccountIdPtr account_id, + uint32_t checkpoint_id); + base::expected>, Error> GetMarksRemoved( + mojom::AccountIdPtr account_id, + uint32_t checkpoint_id); + + base::expected UpdateSubtreeRoots( + mojom::AccountIdPtr account_id, + uint32_t start_index, + std::vector roots); + base::expected, Error> GetShardRoots( + mojom::AccountIdPtr account_id, + uint8_t shard_level); + private: + friend class base::RefCounted; + + ~ZCashOrchardStorage(); + bool EnsureDbInit(); bool CreateOrUpdateDatabase(); bool CreateSchema(); diff --git a/components/brave_wallet/browser/zcash/zcash_orchard_storage_unittest.cc b/components/brave_wallet/browser/zcash/zcash_orchard_storage_unittest.cc index 724743447e2e..07aaf04e6919 100644 --- a/components/brave_wallet/browser/zcash/zcash_orchard_storage_unittest.cc +++ b/components/brave_wallet/browser/zcash/zcash_orchard_storage_unittest.cc @@ -25,14 +25,14 @@ class OrchardStorageTest : public testing::Test { base::test::TaskEnvironment task_environment_; base::ScopedTempDir temp_dir_; - std::unique_ptr orchard_storage_; + scoped_refptr orchard_storage_; }; void OrchardStorageTest::SetUp() { ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); base::FilePath db_path( temp_dir_.GetPath().Append(FILE_PATH_LITERAL("orchard.db"))); - orchard_storage_ = std::make_unique(db_path); + orchard_storage_ = base::WrapRefCounted(new ZCashOrchardStorage(db_path)); } TEST_F(OrchardStorageTest, AccountMeta) { @@ -49,36 +49,33 @@ TEST_F(OrchardStorageTest, AccountMeta) { } EXPECT_TRUE( - orchard_storage_->RegisterAccount(account_id_1.Clone(), 100, "hash") - .has_value()); + orchard_storage_->RegisterAccount(account_id_1.Clone(), 100).has_value()); { auto result = orchard_storage_->GetAccountMeta(account_id_1.Clone()); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->account_birthday, 100u); - EXPECT_EQ(result->latest_scanned_block_id, 100u); - EXPECT_EQ(result->latest_scanned_block_hash, "hash"); + EXPECT_FALSE(result->latest_scanned_block_id); + EXPECT_FALSE(result->latest_scanned_block_hash); } { // Failed to insert same account - EXPECT_EQ( - orchard_storage_->RegisterAccount(account_id_1.Clone(), 200, "hash") - .error() - .error_code, - ZCashOrchardStorage::ErrorCode::kFailedToExecuteStatement); + EXPECT_EQ(orchard_storage_->RegisterAccount(account_id_1.Clone(), 200) + .error() + .error_code, + ZCashOrchardStorage::ErrorCode::kFailedToExecuteStatement); } // Insert second account EXPECT_TRUE( - orchard_storage_->RegisterAccount(account_id_2.Clone(), 200, "hash") - .has_value()); + orchard_storage_->RegisterAccount(account_id_2.Clone(), 200).has_value()); { auto result = orchard_storage_->GetAccountMeta(account_id_2.Clone()); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->account_birthday, 200u); - EXPECT_EQ(result->latest_scanned_block_id, 200u); - EXPECT_EQ(result->latest_scanned_block_hash, "hash"); + EXPECT_FALSE(result->latest_scanned_block_id); + EXPECT_FALSE(result->latest_scanned_block_hash); } } @@ -91,11 +88,9 @@ TEST_F(OrchardStorageTest, PutDiscoveredNotes) { mojom::AccountKind::kDerived, 1); EXPECT_TRUE( - orchard_storage_->RegisterAccount(account_id_1.Clone(), 100, "hash") - .has_value()); + orchard_storage_->RegisterAccount(account_id_1.Clone(), 100).has_value()); EXPECT_TRUE( - orchard_storage_->RegisterAccount(account_id_2.Clone(), 100, "hash") - .has_value()); + orchard_storage_->RegisterAccount(account_id_2.Clone(), 100).has_value()); // Update notes for account 1 { @@ -147,27 +142,30 @@ TEST_F(OrchardStorageTest, PutDiscoveredNotes) { // Update notes for account 1 { std::vector notes; - std::vector nullifiers; + std::vector spends; // Add 1 note, spend 1 note notes.push_back(GenerateMockOrchardNote(account_id_1, 201, 3)); - nullifiers.push_back(GenerateMockNullifier(account_id_1, 203, 1)); + spends.push_back( + OrchardNoteSpend{203, GenerateMockNullifier(account_id_1, 1)}); - orchard_storage_->UpdateNotes(account_id_1.Clone(), notes, nullifiers, 300, + orchard_storage_->UpdateNotes(account_id_1.Clone(), notes, spends, 300, "hash300"); } // Update notes for account 2 { std::vector notes; - std::vector nullifiers; + std::vector spends; // Add 1 note, spend 2 notes notes.push_back(GenerateMockOrchardNote(account_id_2, 211, 4)); - nullifiers.push_back(GenerateMockNullifier(account_id_2, 222, 2)); - nullifiers.push_back(GenerateMockNullifier(account_id_2, 233, 3)); + spends.push_back( + OrchardNoteSpend{222, GenerateMockNullifier(account_id_2, 2)}); + spends.push_back( + OrchardNoteSpend{233, GenerateMockNullifier(account_id_2, 3)}); - orchard_storage_->UpdateNotes(account_id_2.Clone(), notes, nullifiers, 300, + orchard_storage_->UpdateNotes(account_id_2.Clone(), notes, spends, 300, "hash300"); } @@ -218,16 +216,14 @@ TEST_F(OrchardStorageTest, HandleChainReorg) { mojom::AccountKind::kDerived, 1); EXPECT_TRUE( - orchard_storage_->RegisterAccount(account_id_1.Clone(), 100, "hash") - .has_value()); + orchard_storage_->RegisterAccount(account_id_1.Clone(), 100).has_value()); EXPECT_TRUE( - orchard_storage_->RegisterAccount(account_id_2.Clone(), 100, "hash") - .has_value()); + orchard_storage_->RegisterAccount(account_id_2.Clone(), 100).has_value()); // Update notes for account 1 { std::vector notes; - std::vector nullifiers; + std::vector spends; // Add 4 notes, spend 2 notes notes.push_back(GenerateMockOrchardNote(account_id_1, 101, 1)); @@ -236,17 +232,19 @@ TEST_F(OrchardStorageTest, HandleChainReorg) { notes.push_back(GenerateMockOrchardNote(account_id_1, 104, 4)); notes.push_back(GenerateMockOrchardNote(account_id_1, 304, 5)); - nullifiers.push_back(GenerateMockNullifier(account_id_1, 102, 2)); - nullifiers.push_back(GenerateMockNullifier(account_id_1, 103, 3)); + spends.push_back( + OrchardNoteSpend{102, GenerateMockNullifier(account_id_1, 2)}); + spends.push_back( + OrchardNoteSpend{103, GenerateMockNullifier(account_id_1, 3)}); - orchard_storage_->UpdateNotes(account_id_1.Clone(), notes, nullifiers, 450, + orchard_storage_->UpdateNotes(account_id_1.Clone(), notes, spends, 450, "hash450"); } // Update notes for account 2 { std::vector notes; - std::vector nullifiers; + std::vector spends; // Add 4 notes, spend 2 notes notes.push_back(GenerateMockOrchardNote(account_id_2, 211, 1)); @@ -254,10 +252,12 @@ TEST_F(OrchardStorageTest, HandleChainReorg) { notes.push_back(GenerateMockOrchardNote(account_id_2, 213, 3)); notes.push_back(GenerateMockOrchardNote(account_id_2, 414, 4)); - nullifiers.push_back(GenerateMockNullifier(account_id_2, 322, 2)); - nullifiers.push_back(GenerateMockNullifier(account_id_2, 333, 3)); + spends.push_back( + OrchardNoteSpend{322, GenerateMockNullifier(account_id_2, 2)}); + spends.push_back( + OrchardNoteSpend{333, GenerateMockNullifier(account_id_2, 3)}); - orchard_storage_->UpdateNotes(account_id_2.Clone(), notes, nullifiers, 500, + orchard_storage_->UpdateNotes(account_id_2.Clone(), notes, spends, 500, "hash500"); } @@ -350,4 +350,536 @@ TEST_F(OrchardStorageTest, HandleChainReorg) { } } +TEST_F(OrchardStorageTest, Shards) {} + +namespace { + +zcash::mojom::SubtreeRootPtr CreateSubtreeRoot(size_t level, size_t index) { + zcash::mojom::SubtreeRootPtr root = zcash::mojom::SubtreeRoot::New(); + root->root_hash = std::vector(kOrchardShardTreeHashSize, index); + root->complete_block_hash = + std::vector(kOrchardCompleteBlockHashSize, index); + root->complete_block_height = 0; + return root; +} + +OrchardShard CreateShard(size_t index, size_t level) { + OrchardShard orchard_shard; + orchard_shard.root_hash = OrchardShardRootHash(); + orchard_shard.root_hash->fill(static_cast(index)); + orchard_shard.address.index = index; + orchard_shard.address.level = level; + orchard_shard.shard_data = std::vector({0, 0, 0, 0}); + return orchard_shard; +} + +} // namespace + +TEST_F(OrchardStorageTest, InsertSubtreeRoots_BlockHashConflict) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + std::vector level_1_roots; + level_1_roots.push_back(CreateSubtreeRoot(9, 0)); + level_1_roots.push_back(CreateSubtreeRoot(9, 0)); + EXPECT_FALSE( + orchard_storage_ + ->UpdateSubtreeRoots(account_id.Clone(), 0, std::move(level_1_roots)) + .has_value()); +} + +TEST_F(OrchardStorageTest, InsertSubtreeRoots) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + { + std::vector level_1_roots; + for (uint32_t i = 0; i < 10; i++) { + level_1_roots.push_back(CreateSubtreeRoot(9, i)); + } + EXPECT_TRUE(orchard_storage_ + ->UpdateSubtreeRoots(account_id.Clone(), 0, + std::move(level_1_roots)) + .value()); + } + + { + std::vector level_1_addrs; + for (uint32_t i = 0; i < 10; i++) { + level_1_addrs.push_back(OrchardShardAddress{9, i}); + } + auto result = orchard_storage_->GetShardRoots(account_id.Clone(), 9); + + EXPECT_EQ(result.value(), level_1_addrs); + } +} + +TEST_F(OrchardStorageTest, TruncateSubtreeRoots) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + { + std::vector level_1_roots; + for (int i = 0; i < 10; i++) { + level_1_roots.push_back(CreateSubtreeRoot(1, i)); + } + EXPECT_TRUE(orchard_storage_ + ->UpdateSubtreeRoots(account_id.Clone(), 0, + std::move(level_1_roots)) + .value()); + } + + EXPECT_TRUE(orchard_storage_->TruncateShards(account_id.Clone(), 5).value()); + { + std::vector addresses_after_truncate; + for (uint32_t i = 0; i < 5; i++) { + addresses_after_truncate.push_back(OrchardShardAddress{1, i}); + } + auto result = orchard_storage_->GetShardRoots(account_id.Clone(), 1); + EXPECT_EQ(result.value(), addresses_after_truncate); + } +} + +TEST_F(OrchardStorageTest, TruncateShards) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + { + for (uint32_t i = 0; i < 10; i++) { + EXPECT_TRUE( + orchard_storage_->PutShard(account_id.Clone(), CreateShard(i, 1)) + .value()); + } + } + + EXPECT_TRUE(orchard_storage_->TruncateShards(account_id.Clone(), 5).value()); + for (uint32_t i = 0; i < 5; i++) { + EXPECT_EQ(CreateShard(i, 1), + **(orchard_storage_->GetShard(account_id.Clone(), + OrchardShardAddress(1, i)))); + } + + EXPECT_EQ(std::nullopt, *(orchard_storage_->GetShard( + account_id.Clone(), OrchardShardAddress(1, 6)))); +} + +TEST_F(OrchardStorageTest, ShardOverridesSubtreeRoot) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + { + std::vector level_1_roots; + for (uint32_t i = 0; i < 10; i++) { + level_1_roots.push_back(CreateSubtreeRoot(1, i)); + } + EXPECT_TRUE(orchard_storage_ + ->UpdateSubtreeRoots(account_id.Clone(), 0, + std::move(level_1_roots)) + .value()); + } + + // Update existing shard + OrchardShard new_shard; + new_shard.root_hash = OrchardShardRootHash(); + new_shard.address.index = 5; + new_shard.address.level = 1; + new_shard.root_hash->fill(5); + new_shard.shard_data = std::vector({5, 5, 5, 5}); + EXPECT_TRUE( + orchard_storage_->PutShard(account_id.Clone(), new_shard).value()); + + auto result = + orchard_storage_->GetShard(account_id.Clone(), OrchardShardAddress{1, 5}); + EXPECT_EQ(*result.value(), new_shard); +} + +TEST_F(OrchardStorageTest, InsertShards) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + EXPECT_EQ(std::nullopt, + orchard_storage_->GetLatestShardIndex(account_id.Clone()).value()); + EXPECT_EQ( + std::nullopt, + orchard_storage_->GetShard(account_id.Clone(), OrchardShardAddress{1, 0}) + .value()); + EXPECT_EQ(std::nullopt, + orchard_storage_->LastShard(account_id.Clone(), 1).value()); + + { + std::vector level_1_roots; + for (uint32_t i = 0; i < 10; i++) { + level_1_roots.push_back(CreateSubtreeRoot(1, i)); + } + EXPECT_TRUE(orchard_storage_ + ->UpdateSubtreeRoots(account_id.Clone(), 0, + std::move(level_1_roots)) + .value()); + } + + OrchardShard new_shard; + new_shard.root_hash = OrchardShardRootHash(); + new_shard.address.index = 11; + new_shard.address.level = 1; + new_shard.root_hash->fill(11); + new_shard.shard_data = std::vector({1, 1, 1, 1}); + + EXPECT_TRUE( + orchard_storage_->PutShard(account_id.Clone(), new_shard).value()); + + { + auto result = orchard_storage_->GetShard(account_id.Clone(), + OrchardShardAddress{1, 11}); + EXPECT_EQ(*result.value(), new_shard); + } + + { + for (uint32_t i = 0; i < 10; i++) { + auto result = orchard_storage_->GetShard(account_id.Clone(), + OrchardShardAddress{1, i}); + auto root = CreateSubtreeRoot(1, i); + EXPECT_EQ(std::vector(std::begin(*result.value()->root_hash), + std::end(*result.value()->root_hash)), + root->root_hash); + } + } + + EXPECT_EQ(11u, orchard_storage_->GetLatestShardIndex(account_id.Clone()) + .value() + .value()); + EXPECT_EQ(new_shard, + orchard_storage_->LastShard(account_id.Clone(), 1).value()); +} + +TEST_F(OrchardStorageTest, RemoveChekpoint) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + + OrchardCheckpoint checkpoint2; + checkpoint2.marks_removed = std::vector({4, 5, 6}); + checkpoint2.tree_state_position = std::nullopt; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint2) + .value()); + + EXPECT_TRUE( + orchard_storage_->RemoveCheckpoint(account_id.Clone(), 1).value()); + EXPECT_EQ(std::nullopt, + orchard_storage_->GetCheckpoint(account_id.Clone(), 1).value()); + EXPECT_EQ( + OrchardCheckpointBundle(2, checkpoint2), + orchard_storage_->GetCheckpoint(account_id.Clone(), 2).value().value()); +} + +TEST_F(OrchardStorageTest, CheckpointId) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + EXPECT_EQ(std::nullopt, + orchard_storage_->MinCheckpointId(account_id.Clone()).value()); + EXPECT_EQ(std::nullopt, + orchard_storage_->MaxCheckpointId(account_id.Clone()).value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + + OrchardCheckpoint checkpoint2; + checkpoint2.marks_removed = std::vector({1, 2, 3}); + checkpoint2.tree_state_position = 2; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint2) + .value()); + + OrchardCheckpoint checkpoint3; + checkpoint3.marks_removed = std::vector({5}); + checkpoint3.tree_state_position = 3; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 3, checkpoint3) + .value()); + + OrchardCheckpoint checkpoint4; + checkpoint4.marks_removed = std::vector(); + checkpoint4.tree_state_position = std::nullopt; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 4, checkpoint4) + .value()); + + EXPECT_EQ(1, orchard_storage_->MinCheckpointId(account_id.Clone()).value()); + EXPECT_EQ(4, orchard_storage_->MaxCheckpointId(account_id.Clone()).value()); +} + +TEST_F(OrchardStorageTest, CheckpointAtPosition) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + OrchardCheckpoint checkpoint2; + checkpoint2.marks_removed = std::vector({4, 5, 6}); + checkpoint2.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint2) + .value()); + OrchardCheckpoint checkpoint3; + checkpoint3.marks_removed = std::vector({7, 8, 9}); + checkpoint3.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 3, checkpoint3) + .value()); + + EXPECT_EQ(1u, orchard_storage_->GetCheckpointAtDepth(account_id.Clone(), 2) + .value() + .value()); + EXPECT_EQ( + std::nullopt, + orchard_storage_->GetCheckpointAtDepth(account_id.Clone(), 5).value()); +} + +TEST_F(OrchardStorageTest, TruncateCheckpoints_OutOfBoundry) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + + EXPECT_TRUE( + orchard_storage_->TruncateCheckpoints(account_id.Clone(), 3).value()); + + EXPECT_EQ( + OrchardCheckpointBundle(1, checkpoint1), + orchard_storage_->GetCheckpoint(account_id.Clone(), 1).value().value()); +} + +TEST_F(OrchardStorageTest, TruncateCheckpoints) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + + OrchardCheckpoint checkpoint2; + checkpoint2.marks_removed = std::vector({1, 2, 3}); + checkpoint2.tree_state_position = 2; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint2) + .value()); + + OrchardCheckpoint checkpoint3; + checkpoint3.marks_removed = std::vector({5}); + checkpoint3.tree_state_position = 3; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 3, checkpoint3) + .value()); + + OrchardCheckpoint checkpoint4; + checkpoint4.marks_removed = std::vector(); + checkpoint4.tree_state_position = std::nullopt; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 4, checkpoint4) + .value()); + + EXPECT_TRUE( + orchard_storage_->TruncateCheckpoints(account_id.Clone(), 3).value()); + + EXPECT_EQ( + OrchardCheckpointBundle(1, checkpoint1), + orchard_storage_->GetCheckpoint(account_id.Clone(), 1).value().value()); + EXPECT_EQ( + OrchardCheckpointBundle(2, checkpoint2), + orchard_storage_->GetCheckpoint(account_id.Clone(), 2).value().value()); + EXPECT_EQ(std::nullopt, + orchard_storage_->GetCheckpoint(account_id.Clone(), 3).value()); + EXPECT_EQ(std::nullopt, + orchard_storage_->GetCheckpoint(account_id.Clone(), 4).value()); +} + +TEST_F(OrchardStorageTest, AddCheckpoint) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + OrchardCheckpoint checkpoint2; + checkpoint2.marks_removed = std::vector({4, 5, 6}); + checkpoint2.tree_state_position = std::nullopt; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint2) + .value()); + OrchardCheckpoint checkpoint3; + checkpoint3.marks_removed = std::vector(); + checkpoint3.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 3, checkpoint3) + .value()); + + EXPECT_EQ( + OrchardCheckpointBundle(1, checkpoint1), + orchard_storage_->GetCheckpoint(account_id.Clone(), 1).value().value()); + EXPECT_EQ( + OrchardCheckpointBundle(2, checkpoint2), + orchard_storage_->GetCheckpoint(account_id.Clone(), 2).value().value()); + EXPECT_EQ( + OrchardCheckpointBundle(3, checkpoint3), + orchard_storage_->GetCheckpoint(account_id.Clone(), 3).value().value()); +} + +TEST_F(OrchardStorageTest, AddSameCheckpoint) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + { + OrchardCheckpoint checkpoint; + checkpoint.marks_removed = std::vector({1, 2, 3}); + checkpoint.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint) + .value()); + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint) + .value()); + + EXPECT_EQ( + OrchardCheckpointBundle(1, checkpoint), + orchard_storage_->GetCheckpoint(account_id.Clone(), 1).value().value()); + } + + { + OrchardCheckpoint checkpoint; + checkpoint.marks_removed = std::vector({1, 2, 3}); + checkpoint.tree_state_position = std::nullopt; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint) + .value()); + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 2, checkpoint) + .value()); + + EXPECT_EQ( + OrchardCheckpointBundle(2, checkpoint), + orchard_storage_->GetCheckpoint(account_id.Clone(), 2).value().value()); + } + + { + OrchardCheckpoint checkpoint; + checkpoint.marks_removed = std::vector(); + checkpoint.tree_state_position = std::nullopt; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 3, checkpoint) + .value()); + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 3, checkpoint) + .value()); + + EXPECT_EQ( + OrchardCheckpointBundle(3, checkpoint), + orchard_storage_->GetCheckpoint(account_id.Clone(), 3).value().value()); + } +} + +TEST_F(OrchardStorageTest, AddChekpoint_ErrorOnConflict) { + auto account_id = MakeIndexBasedAccountId(mojom::CoinType::ZEC, + mojom::KeyringId::kZCashMainnet, + mojom::AccountKind::kDerived, 0); + EXPECT_TRUE( + orchard_storage_->RegisterAccount(account_id.Clone(), 100).has_value()); + + OrchardCheckpoint checkpoint1; + checkpoint1.marks_removed = std::vector({1, 2, 3}); + checkpoint1.tree_state_position = 4; + EXPECT_TRUE( + orchard_storage_->AddCheckpoint(account_id.Clone(), 1, checkpoint1) + .value()); + + OrchardCheckpoint checkpoint_different_marks_removed = checkpoint1; + checkpoint_different_marks_removed.marks_removed = + std::vector({1, 2}); + EXPECT_FALSE(orchard_storage_ + ->AddCheckpoint(account_id.Clone(), 1, + checkpoint_different_marks_removed) + .has_value()); + + OrchardCheckpoint checkpoint_different_position1 = checkpoint1; + checkpoint_different_position1.tree_state_position = 7; + EXPECT_FALSE( + orchard_storage_ + ->AddCheckpoint(account_id.Clone(), 1, checkpoint_different_position1) + .has_value()); + + OrchardCheckpoint checkpoint_different_position2 = checkpoint1; + checkpoint_different_position2.tree_state_position = std::nullopt; + EXPECT_FALSE( + orchard_storage_ + ->AddCheckpoint(account_id.Clone(), 1, checkpoint_different_position2) + .has_value()); + + EXPECT_EQ( + OrchardCheckpointBundle(1, checkpoint1), + orchard_storage_->GetCheckpoint(account_id.Clone(), 1).value().value()); +} + } // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.cc b/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.cc new file mode 100644 index 000000000000..dd26f55ab0da --- /dev/null +++ b/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.cc @@ -0,0 +1,157 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h" + +#include + +#include "base/check_is_test.h" +#include "base/containers/extend.h" +#include "brave/components/brave_wallet/browser/zcash/orchard_shard_tree_delegate_impl.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet { + +ZCashOrchardSyncState::ZCashOrchardSyncState(base::FilePath path_to_database) { + storage_ = base::MakeRefCounted(path_to_database); +} + +ZCashOrchardSyncState::~ZCashOrchardSyncState() {} + +// static +void ZCashOrchardSyncState::OverrideShardTreeManagerForTesting( + const mojom::AccountIdPtr& account_id, + std::unique_ptr manager) { + CHECK_IS_TEST(); + shard_tree_managers_[account_id.Clone()] = + OrchardShardTreeManager::CreateForTesting( + std::make_unique(account_id.Clone(), + storage_)); +} + +OrchardShardTreeManager* ZCashOrchardSyncState::GetOrCreateShardTreeManager( + const mojom::AccountIdPtr& account_id) { + if (shard_tree_managers_.find(account_id) == shard_tree_managers_.end()) { + shard_tree_managers_[account_id.Clone()] = OrchardShardTreeManager::Create( + std::make_unique(account_id.Clone(), + storage_)); + } + return shard_tree_managers_[account_id.Clone()].get(); +} + +base::expected +ZCashOrchardSyncState::RegisterAccount(mojom::AccountIdPtr account_id, + uint64_t account_birthday_block) { + return storage_->RegisterAccount(std::move(account_id), + account_birthday_block); +} + +base::expected +ZCashOrchardSyncState::GetAccountMeta(mojom::AccountIdPtr account_id) { + return storage_->GetAccountMeta(std::move(account_id)); +} + +std::optional +ZCashOrchardSyncState::HandleChainReorg(mojom::AccountIdPtr account_id, + uint32_t reorg_block_id, + const std::string& reorg_block_hash) { + return storage_->HandleChainReorg(std::move(account_id), reorg_block_id, + reorg_block_hash); +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardSyncState::GetSpendableNotes(mojom::AccountIdPtr account_id) { + return storage_->GetSpendableNotes(std::move(account_id)); +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardSyncState::GetNullifiers(mojom::AccountIdPtr account_id) { + return storage_->GetNullifiers(std::move(account_id)); +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardSyncState::GetLatestShardIndex(mojom::AccountIdPtr account_id) { + return storage_->GetLatestShardIndex(std::move(account_id)); +} + +base::expected +ZCashOrchardSyncState::UpdateSubtreeRoots( + mojom::AccountIdPtr account_id, + uint32_t start_index, + std::vector roots) { + // return storage_->UpdateSubtreeRoots(std::move(account_id), start_index, + // std::move(roots)); + return true; +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardSyncState::GetMaxCheckpointedHeight(mojom::AccountIdPtr account_id, + uint32_t chain_tip_height, + size_t min_confirmations) { + return storage_->GetMaxCheckpointedHeight( + account_id.Clone(), chain_tip_height, min_confirmations); +} + +std::optional ZCashOrchardSyncState::UpdateNotes( + mojom::AccountIdPtr account_id, + OrchardBlockScanner::Result block_scanner_results, + const uint32_t latest_scanned_block, + const std::string& latest_scanned_block_hash) { + auto existing_notes = storage_->GetSpendableNotes(account_id.Clone()); + if (!existing_notes.has_value()) { + return existing_notes.error(); + } + + std::vector notes_to_add = + block_scanner_results.discovered_notes; + base::Extend(existing_notes.value(), notes_to_add); + + std::vector nf_to_add; + + for (const auto& nf : block_scanner_results.found_spends) { + if (std::find_if(existing_notes.value().begin(), + existing_notes.value().end(), [&nf](const auto& v) { + return v.nullifier == nf.nullifier; + }) != existing_notes.value().end()) { + nf_to_add.push_back(nf); + } + } + + if (!GetOrCreateShardTreeManager(account_id.Clone()) + ->InsertCommitments(std::move(block_scanner_results))) { + return base::unexpected(ZCashOrchardStorage::Error{ + ZCashOrchardStorage::ErrorCode::kInternalError, + "Failed to insert commitments"}); + } + + return storage_->UpdateNotes(std::move(account_id), notes_to_add, + std::move(nf_to_add), latest_scanned_block, + latest_scanned_block_hash); +} + +base::expected +ZCashOrchardSyncState::ResetAccountSyncState(mojom::AccountIdPtr account_id) { + return storage_->ResetAccountSyncState(std::move(account_id)); +} + +base::expected, ZCashOrchardStorage::Error> +ZCashOrchardSyncState::CalculateWitnessForCheckpoint( + mojom::AccountIdPtr account_id, + std::vector notes, + uint32_t checkpoint_position) { + auto result = GetOrCreateShardTreeManager(account_id.Clone()) + ->CalculateWitness(notes, checkpoint_position); + if (!result.has_value()) { + return base::unexpected(ZCashOrchardStorage::Error{ + ZCashOrchardStorage::ErrorCode::kConsistencyError, result.error()}); + } + return result.value(); +} + +void ZCashOrchardSyncState::ResetDatabase() { + storage_->ResetDatabase(); +} + +} // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h b/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h new file mode 100644 index 000000000000..7a5d75486d22 --- /dev/null +++ b/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ZCASH_ORCHARD_SYNC_STATE_H_ +#define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ZCASH_ORCHARD_SYNC_STATE_H_ + +#include +#include +#include +#include + +#include "brave/components/brave_wallet/browser/internal/orchard_shard_tree_manager.h" +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_storage.h" +#include "brave/components/brave_wallet/common/zcash_utils.h" + +namespace brave_wallet { + +class ZCashOrchardSyncState { + public: + explicit ZCashOrchardSyncState(base::FilePath path_to_database); + ~ZCashOrchardSyncState(); + + base::expected + RegisterAccount(mojom::AccountIdPtr account_id, + uint64_t account_birthday_block); + + base::expected + GetAccountMeta(mojom::AccountIdPtr account_id); + + std::optional HandleChainReorg( + mojom::AccountIdPtr account_id, + uint32_t reorg_block_id, + const std::string& reorg_block_hash); + + base::expected, ZCashOrchardStorage::Error> + GetSpendableNotes(mojom::AccountIdPtr account_id); + + base::expected, ZCashOrchardStorage::Error> + GetNullifiers(mojom::AccountIdPtr account_id); + + std::optional UpdateNotes( + mojom::AccountIdPtr account_id, + OrchardBlockScanner::Result block_scanner_results, + const uint32_t latest_scanned_block, + const std::string& latest_scanned_block_hash); + + base::expected ResetAccountSyncState( + mojom::AccountIdPtr account_id); + void ResetDatabase(); + + base::expected, ZCashOrchardStorage::Error> + GetLatestShardIndex(mojom::AccountIdPtr account_id); + + base::expected UpdateSubtreeRoots( + mojom::AccountIdPtr account_id, + uint32_t start_index, + std::vector roots); + + base::expected, ZCashOrchardStorage::Error> + GetMaxCheckpointedHeight(mojom::AccountIdPtr account_id, + uint32_t chain_tip_height, + size_t min_confirmations); + + base::expected, ZCashOrchardStorage::Error> + CalculateWitnessForCheckpoint(mojom::AccountIdPtr account_id, + std::vector notes, + uint32_t checkpoint_position); + + private: + void OverrideShardTreeManagerForTesting( + const mojom::AccountIdPtr& account_id, + std::unique_ptr manager); + + OrchardShardTreeManager* GetOrCreateShardTreeManager( + const mojom::AccountIdPtr& account_id); + + scoped_refptr storage_; + std::map> + shard_tree_managers_; +}; + +} // namespace brave_wallet + +#endif // BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ZCASH_ORCHARD_SYNC_STATE_H_ diff --git a/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.cc b/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.cc index f2abbe604cc0..07f34d3bd0f1 100644 --- a/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.cc +++ b/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.cc @@ -5,6 +5,8 @@ #include "brave/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.h" +#include + #include "brave/components/brave_wallet/common/common_utils.h" #include "components/grit/brave_components_strings.h" #include "ui/base/l10n/l10n_util.h" @@ -55,8 +57,8 @@ void ZCashResolveBalanceTask::WorkOnTask() { #if BUILDFLAG(ENABLE_ORCHARD) if (IsZCashShieldedTransactionsEnabled()) { if (!orchard_notes_) { - zcash_wallet_service_->orchard_storage() - .AsyncCall(&ZCashOrchardStorage::GetSpendableNotes) + zcash_wallet_service_->sync_state() + .AsyncCall(&ZCashOrchardSyncState::GetSpendableNotes) .WithArgs(account_id_.Clone()) .Then(base::BindOnce(&ZCashResolveBalanceTask::OnGetSpendableNotes, weak_ptr_factory_.GetWeakPtr())); diff --git a/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.h b/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.h index e7b9a88719a2..bba738c3d800 100644 --- a/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.h +++ b/components/brave_wallet/browser/zcash/zcash_resolve_balance_task.h @@ -7,6 +7,8 @@ #define BRAVE_COMPONENTS_BRAVE_WALLET_BROWSER_ZCASH_ZCASH_RESOLVE_BALANCE_TASK_H_ #include +#include +#include #include "base/types/expected.h" #include "brave/components/brave_wallet/browser/zcash/zcash_wallet_service.h" diff --git a/components/brave_wallet/browser/zcash/zcash_shield_sync_service.cc b/components/brave_wallet/browser/zcash/zcash_shield_sync_service.cc index 3498ae977838..3dac379853fa 100644 --- a/components/brave_wallet/browser/zcash/zcash_shield_sync_service.cc +++ b/components/brave_wallet/browser/zcash/zcash_shield_sync_service.cc @@ -52,13 +52,13 @@ ZCashShieldSyncService::OrchardBlockScannerProxy::~OrchardBlockScannerProxy() = default; void ZCashShieldSyncService::OrchardBlockScannerProxy::ScanBlocks( - std::vector known_notes, + OrchardTreeState tree_state, std::vector blocks, base::OnceCallback)> callback) { background_block_scanner_.AsyncCall(&OrchardBlockScanner::ScanBlocks) - .WithArgs(std::move(known_notes), std::move(blocks)) + .WithArgs(std::move(tree_state), std::move(blocks)) .Then(std::move(callback)); } @@ -155,8 +155,8 @@ void ZCashShieldSyncService::WorkOnTask() { } void ZCashShieldSyncService::GetOrCreateAccount() { - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::GetAccountMeta) + sync_state() + .AsyncCall(&ZCashOrchardSyncState::GetAccountMeta) .WithArgs(account_id_.Clone()) .Then(base::BindOnce(&ZCashShieldSyncService::OnGetAccountMeta, weak_ptr_factory_.GetWeakPtr())); @@ -167,8 +167,9 @@ void ZCashShieldSyncService::OnGetAccountMeta( result) { if (result.has_value()) { account_meta_ = *result; - if (account_meta_->latest_scanned_block_id < - account_meta_->account_birthday) { + if (account_meta_->latest_scanned_block_id.value() && + (account_meta_->latest_scanned_block_id.value() < + account_meta_->account_birthday)) { error_ = Error{ErrorCode::kFailedToRetrieveAccount, ""}; } ScheduleWorkOnTask(); @@ -186,10 +187,9 @@ void ZCashShieldSyncService::OnGetAccountMeta( } void ZCashShieldSyncService::InitAccount() { - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::RegisterAccount) - .WithArgs(account_id_.Clone(), account_birthday_->value, - account_birthday_->hash) + sync_state() + .AsyncCall(&ZCashOrchardSyncState::RegisterAccount) + .WithArgs(account_id_.Clone(), account_birthday_->value) .Then(base::BindOnce(&ZCashShieldSyncService::OnAccountInit, weak_ptr_factory_.GetWeakPtr())); } @@ -207,9 +207,14 @@ void ZCashShieldSyncService::OnAccountInit( void ZCashShieldSyncService::VerifyChainState( ZCashOrchardStorage::AccountMeta account_meta) { + if (!account_meta.latest_scanned_block_id) { + latest_scanned_block_ = account_meta.account_birthday - 1; + ScheduleWorkOnTask(); + return; + } // If block chain has removed blocks we already scanned then we need to handle // chain reorg. - if (*chain_tip_block_ < account_meta.latest_scanned_block_id) { + if (*chain_tip_block_ < account_meta.latest_scanned_block_id.value()) { // Assume that chain reorg can't affect more than kChainReorgBlockDelta // blocks So we can just fallback on this number from the chain tip block. GetTreeStateForChainReorg(*chain_tip_block_ - kChainReorgBlockDelta); @@ -218,7 +223,7 @@ void ZCashShieldSyncService::VerifyChainState( // Retrieve block info for last scanned block id to check whether block hash // is the same auto block_id = zcash::mojom::BlockID::New( - account_meta.latest_scanned_block_id, std::vector()); + account_meta.latest_scanned_block_id.value(), std::vector()); zcash_rpc()->GetTreeState( chain_id_, std::move(block_id), base::BindOnce( @@ -235,12 +240,13 @@ void ZCashShieldSyncService::OnGetTreeStateForChainVerification( return; } auto backend_block_hash = RevertHex(tree_state.value()->hash); - if (backend_block_hash != account_meta.latest_scanned_block_hash) { + if (backend_block_hash != account_meta.latest_scanned_block_hash.value()) { // Assume that chain reorg can't affect more than kChainReorgBlockDelta // blocks So we can just fallback on this number. uint32_t new_block_id = - account_meta.latest_scanned_block_id > kChainReorgBlockDelta - ? account_meta.latest_scanned_block_id - kChainReorgBlockDelta + account_meta.latest_scanned_block_id.value() > kChainReorgBlockDelta + ? account_meta.latest_scanned_block_id.value() - + kChainReorgBlockDelta : 0; GetTreeStateForChainReorg(new_block_id); return; @@ -273,8 +279,8 @@ void ZCashShieldSyncService::OnGetTreeStateForChainReorg( return; } else { // Reorg database so records related to removed blocks are wiped out - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::HandleChainReorg) + sync_state() + .AsyncCall(&ZCashOrchardSyncState::HandleChainReorg) .WithArgs(account_id_.Clone(), (*tree_state)->height, (*tree_state)->hash) .Then(base::BindOnce( @@ -297,8 +303,8 @@ void ZCashShieldSyncService::OnDatabaseUpdatedForChainReorg( } void ZCashShieldSyncService::UpdateSpendableNotes() { - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::GetSpendableNotes) + sync_state() + .AsyncCall(&ZCashOrchardSyncState::GetSpendableNotes) .WithArgs(account_id_.Clone()) .Then(base::BindOnce(&ZCashShieldSyncService::OnGetSpendableNotes, weak_ptr_factory_.GetWeakPtr())); @@ -365,7 +371,7 @@ void ZCashShieldSyncService::ScanBlocks() { auto last_block_height = downloaded_blocks_->back()->height; block_scanner_->ScanBlocks( - *spendable_notes_, std::move(downloaded_blocks_.value()), + OrchardTreeState(), std::move(downloaded_blocks_.value()), base::BindOnce(&ZCashShieldSyncService::OnBlocksScanned, weak_ptr_factory_.GetWeakPtr(), last_block_height, last_block_hash)); @@ -381,20 +387,18 @@ void ZCashShieldSyncService::OnBlocksScanned( error_ = Error{ErrorCode::kScannerError, ""}; ScheduleWorkOnTask(); } else { - UpdateNotes(result->discovered_notes, result->spent_notes, - last_block_height, last_block_hash); + UpdateNotes(std::move(result.value()), last_block_height, last_block_hash); } } void ZCashShieldSyncService::UpdateNotes( - const std::vector& found_notes, - const std::vector& notes_to_delete, + OrchardBlockScanner::Result result, uint32_t latest_scanned_block, std::string latest_scanned_block_hash) { - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::UpdateNotes) - .WithArgs(account_id_.Clone(), found_notes, notes_to_delete, - latest_scanned_block, latest_scanned_block_hash) + sync_state() + .AsyncCall(&ZCashOrchardSyncState::UpdateNotes) + .WithArgs(account_id_.Clone(), std::move(result), latest_scanned_block, + latest_scanned_block_hash) .Then(base::BindOnce(&ZCashShieldSyncService::UpdateNotesComplete, weak_ptr_factory_.GetWeakPtr(), latest_scanned_block)); @@ -425,9 +429,9 @@ ZCashRpc* ZCashShieldSyncService::zcash_rpc() { return zcash_wallet_service_->zcash_rpc(); } -base::SequenceBound& -ZCashShieldSyncService::orchard_storage() { - return zcash_wallet_service_->orchard_storage(); +base::SequenceBound& +ZCashShieldSyncService::sync_state() { + return zcash_wallet_service_->sync_state(); } } // namespace brave_wallet diff --git a/components/brave_wallet/browser/zcash/zcash_shield_sync_service.h b/components/brave_wallet/browser/zcash/zcash_shield_sync_service.h index de2ed144ec44..8531ab59f505 100644 --- a/components/brave_wallet/browser/zcash/zcash_shield_sync_service.h +++ b/components/brave_wallet/browser/zcash/zcash_shield_sync_service.h @@ -15,7 +15,7 @@ #include "base/threading/sequence_bound.h" #include "base/types/expected.h" #include "brave/components/brave_wallet/browser/internal/orchard_block_scanner.h" -#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_storage.h" +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h" #include "brave/components/brave_wallet/common/brave_wallet.mojom.h" #include "mojo/public/cpp/bindings/remote.h" @@ -64,7 +64,7 @@ class ZCashShieldSyncService { explicit OrchardBlockScannerProxy(OrchardFullViewKey full_view_key); virtual ~OrchardBlockScannerProxy(); virtual void ScanBlocks( - std::vector known_notes, + OrchardTreeState tree_state, std::vector blocks, base::OnceCallback)> @@ -147,15 +147,14 @@ class ZCashShieldSyncService { std::string last_block_hash, base::expected result); - void UpdateNotes(const std::vector& found_notes, - const std::vector& notes_to_delete, + void UpdateNotes(OrchardBlockScanner::Result result, uint32_t latest_scanned_block, std::string latest_scanned_block_hash); void UpdateNotesComplete(uint32_t new_latest_scanned_block, std::optional error); ZCashRpc* zcash_rpc(); - base::SequenceBound& orchard_storage(); + base::SequenceBound& sync_state(); uint32_t GetSpendableBalance(); std::optional error() { return error_; } diff --git a/components/brave_wallet/browser/zcash/zcash_shield_sync_service_unittest.cc b/components/brave_wallet/browser/zcash/zcash_shield_sync_service_unittest.cc index 3d1b787e86b5..c8251d0fe6ca 100644 --- a/components/brave_wallet/browser/zcash/zcash_shield_sync_service_unittest.cc +++ b/components/brave_wallet/browser/zcash/zcash_shield_sync_service_unittest.cc @@ -95,8 +95,8 @@ class MockOrchardBlockScannerProxy : public ZCashShieldSyncService::OrchardBlockScannerProxy { public: using Callback = base::RepeatingCallback known_notes, - std::vector blocks, + OrchardTreeState, + std::vector, base::OnceCallback)> callback)>; @@ -107,12 +107,12 @@ class MockOrchardBlockScannerProxy ~MockOrchardBlockScannerProxy() override = default; void ScanBlocks( - std::vector known_notes, + OrchardTreeState tree_state, std::vector blocks, base::OnceCallback)> callback) override { - callback_.Run(std::move(known_notes), std::move(blocks), + callback_.Run(std::move(tree_state), std::move(blocks), std::move(callback)); } @@ -159,7 +159,7 @@ class ZCashShieldSyncServiceTest : public testing::Test { std::unique_ptr CreateMockOrchardBlockScannerProxy() { return std::make_unique(base::BindRepeating( - [](std::vector known_notes, + [](OrchardTreeState tree_state, std::vector blocks, base::OnceCallback()); for (const auto& block : blocks) { // 3 notes in the blockchain if (block->height == 105) { @@ -183,14 +185,14 @@ class ZCashShieldSyncServiceTest : public testing::Test { // First 2 notes are spent if (block->height == 255) { - result.spent_notes.push_back( - GenerateMockNullifier(account_id, block->height, 1)); + result.found_spends.push_back(OrchardNoteSpend{ + block->height, GenerateMockNullifier(account_id, 1)}); } else if (block->height == 265) { - result.spent_notes.push_back( - GenerateMockNullifier(account_id, block->height, 2)); + result.found_spends.push_back(OrchardNoteSpend{ + block->height, GenerateMockNullifier(account_id, 2)}); } } - std::move(callback).Run(result); + std::move(callback).Run(std::move(result)); })); } @@ -276,7 +278,7 @@ TEST_F(ZCashShieldSyncServiceTest, ScanBlocks) { sync_service()->SetOrchardBlockScannerProxyForTesting( std::make_unique(base::BindRepeating( - [](std::vector known_notes, + [](OrchardTreeState tree_state, std::vector blocks, base::OnceCallback()); for (const auto& block : blocks) { // 3 notes in the blockchain if (block->height == 605) { @@ -300,11 +304,11 @@ TEST_F(ZCashShieldSyncServiceTest, ScanBlocks) { // First 2 notes are spent if (block->height == 855) { - result.spent_notes.push_back( - GenerateMockNullifier(account_id, block->height, 3)); + result.found_spends.push_back(OrchardNoteSpend{ + block->height, GenerateMockNullifier(account_id, 3)}); } } - std::move(callback).Run(result); + std::move(callback).Run(std::move(result)); }))); ON_CALL(*zcash_rpc(), GetTreeState(_, _, _)) @@ -349,13 +353,15 @@ TEST_F(ZCashShieldSyncServiceTest, ScanBlocks) { sync_service()->SetOrchardBlockScannerProxyForTesting( std::make_unique(base::BindRepeating( - [](std::vector known_notes, + [](OrchardTreeState tree_state, std::vector blocks, base::OnceCallback)> callback) { - OrchardBlockScanner::Result result; - std::move(callback).Run(result); + OrchardBlockScanner::Result result = + OrchardBlockScanner::CreateResultForTesting( + tree_state, std::vector()); + std::move(callback).Run(std::move(result)); }))); { @@ -389,7 +395,7 @@ TEST_F(ZCashShieldSyncServiceTest, ScanBlocks) { sync_service()->SetOrchardBlockScannerProxyForTesting( std::make_unique(base::BindRepeating( - [](std::vector known_notes, + [](OrchardTreeState tree_state, std::vector blocks, base::OnceCallback()); for (const auto& block : blocks) { // First block is the current chain tip - kChainReorgBlockDelta EXPECT_GE(block->height, 950u - kChainReorgBlockDelta); @@ -412,11 +420,11 @@ TEST_F(ZCashShieldSyncServiceTest, ScanBlocks) { // Add a nullifier for previous note if (block->height == 905) { - result.spent_notes.push_back( - GenerateMockNullifier(account_id, block->height, 3)); + result.found_spends.push_back(OrchardNoteSpend{ + block->height, GenerateMockNullifier(account_id, 3)}); } } - std::move(callback).Run(result); + std::move(callback).Run(std::move(result)); }))); { diff --git a/components/brave_wallet/browser/zcash/zcash_test_utils.cc b/components/brave_wallet/browser/zcash/zcash_test_utils.cc index 9eaea7f89f0d..fa86a79b7b65 100644 --- a/components/brave_wallet/browser/zcash/zcash_test_utils.cc +++ b/components/brave_wallet/browser/zcash/zcash_test_utils.cc @@ -10,26 +10,24 @@ namespace brave_wallet { -std::array GenerateMockNullifier( - const mojom::AccountIdPtr& account_id, - uint8_t seed) { +OrchardNullifier GenerateMockNullifier(const mojom::AccountIdPtr& account_id, + uint8_t seed) { std::array nullifier; nullifier.fill(seed); nullifier[0] = account_id->account_index; return nullifier; } -OrchardNullifier GenerateMockNullifier(const mojom::AccountIdPtr& account_id, - uint32_t block_id, - uint8_t seed) { - return OrchardNullifier{block_id, GenerateMockNullifier(account_id, seed)}; -} - OrchardNote GenerateMockOrchardNote(const mojom::AccountIdPtr& account_id, uint32_t block_id, uint8_t seed) { - return OrchardNote{block_id, GenerateMockNullifier(account_id, seed), - static_cast(seed * 10)}; + return OrchardNote{{}, + block_id, + GenerateMockNullifier(account_id, seed), + static_cast(seed * 10), + 0, + {}, + {}}; } void SortByBlockId(std::vector& vec) { diff --git a/components/brave_wallet/browser/zcash/zcash_test_utils.h b/components/brave_wallet/browser/zcash/zcash_test_utils.h index 7131666efb06..08a56304e33e 100644 --- a/components/brave_wallet/browser/zcash/zcash_test_utils.h +++ b/components/brave_wallet/browser/zcash/zcash_test_utils.h @@ -13,12 +13,7 @@ namespace brave_wallet { -std::array GenerateMockNullifier( - const mojom::AccountIdPtr& account_id, - uint8_t seed); - OrchardNullifier GenerateMockNullifier(const mojom::AccountIdPtr& account_id, - uint32_t block_id, uint8_t seed); OrchardNote GenerateMockOrchardNote(const mojom::AccountIdPtr& account_id, diff --git a/components/brave_wallet/browser/zcash/zcash_wallet_service.cc b/components/brave_wallet/browser/zcash/zcash_wallet_service.cc index e8b49a22d086..89015291eab2 100644 --- a/components/brave_wallet/browser/zcash/zcash_wallet_service.cc +++ b/components/brave_wallet/browser/zcash/zcash_wallet_service.cc @@ -64,7 +64,7 @@ ZCashWalletService::ZCashWalletService( keyring_observer_receiver_.BindNewPipeAndPassRemote()); complete_manager_ = std::make_unique(this); #if BUILDFLAG(ENABLE_ORCHARD) - background_orchard_storage_.emplace( + sync_state_.emplace( base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}), zcash_data_path_.AppendASCII(kOrchardDatabaseName)); #endif @@ -83,7 +83,7 @@ ZCashWalletService::ZCashWalletService(base::FilePath zcash_data_path, } complete_manager_ = std::make_unique(this); #if BUILDFLAG(ENABLE_ORCHARD) - background_orchard_storage_.emplace( + sync_state_.emplace( base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}), zcash_data_path_.AppendASCII(kOrchardDatabaseName)); #endif @@ -754,9 +754,8 @@ KeyringService* ZCashWalletService::keyring_service() { } #if BUILDFLAG(ENABLE_ORCHARD) -base::SequenceBound& -ZCashWalletService::orchard_storage() { - return background_orchard_storage_; +base::SequenceBound& ZCashWalletService::sync_state() { + return sync_state_; } #endif // BUILDFLAG(ENABLE_ORCHARD) @@ -779,7 +778,7 @@ void ZCashWalletService::Reset() { weak_ptr_factory_.InvalidateWeakPtrs(); #if BUILDFLAG(ENABLE_ORCHARD) shield_sync_services_.clear(); - background_orchard_storage_.AsyncCall(&ZCashOrchardStorage::ResetDatabase); + sync_state_.AsyncCall(&ZCashOrchardSyncState::ResetDatabase); #endif // BUILDFLAG(ENABLE_ORCHARD) } diff --git a/components/brave_wallet/browser/zcash/zcash_wallet_service.h b/components/brave_wallet/browser/zcash/zcash_wallet_service.h index 213c02f87bb1..65eab30a76f4 100644 --- a/components/brave_wallet/browser/zcash/zcash_wallet_service.h +++ b/components/brave_wallet/browser/zcash/zcash_wallet_service.h @@ -25,11 +25,16 @@ #include "brave/components/brave_wallet/common/buildflags.h" #include "brave/components/brave_wallet/common/zcash_utils.h" +#if BUILDFLAG(ENABLE_ORCHARD) +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h" +#endif + namespace brave_wallet { class ZCashCreateShieldTransactionTask; class ZCashCreateTransparentTransactionTask; class ZCashGetTransparentUtxosContext; +class ZCashOrchardSyncState; class ZCashResolveBalanceTask; class ZCashWalletService : public mojom::ZCashWalletService, @@ -250,7 +255,7 @@ class ZCashWalletService : public mojom::ZCashWalletService, const mojom::AccountIdPtr& account_id, const mojom::ZCashShieldSyncStatusPtr& status) override; - base::SequenceBound& orchard_storage(); + base::SequenceBound& sync_state(); #endif void UpdateNextUnusedAddressForAccount(const mojom::AccountIdPtr& account_id, @@ -269,7 +274,7 @@ class ZCashWalletService : public mojom::ZCashWalletService, std::list> resolve_balance_tasks_; #if BUILDFLAG(ENABLE_ORCHARD) - base::SequenceBound background_orchard_storage_; + base::SequenceBound sync_state_; std::list> create_shield_transaction_tasks_; std::map> diff --git a/components/brave_wallet/browser/zcash/zcash_wallet_service_unittest.cc b/components/brave_wallet/browser/zcash/zcash_wallet_service_unittest.cc index a2f66cd23ded..01842a7e0c6d 100644 --- a/components/brave_wallet/browser/zcash/zcash_wallet_service_unittest.cc +++ b/components/brave_wallet/browser/zcash/zcash_wallet_service_unittest.cc @@ -36,7 +36,7 @@ #include "base/task/sequenced_task_runner.h" #include "base/test/scoped_run_loop_timeout.h" #include "brave/components/brave_wallet/browser/internal/orchard_bundle_manager.h" -#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_storage.h" +#include "brave/components/brave_wallet/browser/zcash/zcash_orchard_sync_state.h" #endif using testing::_; @@ -141,8 +141,8 @@ class ZCashWalletServiceUnitTest : public testing::Test { } #if BUILDFLAG(ENABLE_ORCHARD) - base::SequenceBound& orchard_storage() { - return zcash_wallet_service_->orchard_storage(); + base::SequenceBound& sync_state() { + return zcash_wallet_service_->sync_state(); } #endif // BUILDFLAG(ENABLE_ORCHARD) @@ -328,10 +328,14 @@ TEST_F(ZCashWalletServiceUnitTest, GetBalanceWithShielded) { auto update_notes_callback = base::BindLambdaForTesting( [](std::optional) {}); - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::UpdateNotes) - .WithArgs(account_id.Clone(), std::vector({note}), - std::vector(), 50000, "hash50000") + OrchardBlockScanner::Result result = + OrchardBlockScanner::CreateResultForTesting( + OrchardTreeState(), std::vector()); + result.discovered_notes = std::vector({note}); + + sync_state() + .AsyncCall(&ZCashOrchardSyncState::UpdateNotes) + .WithArgs(account_id.Clone(), std::move(result), 50000, "hash50000") .Then(std::move(update_notes_callback)); task_environment_.RunUntilIdle(); @@ -413,10 +417,14 @@ TEST_F(ZCashWalletServiceUnitTest, GetBalanceWithShielded_FeatureDisabled) { auto update_notes_callback = base::BindLambdaForTesting( [](std::optional) {}); - orchard_storage() - .AsyncCall(&ZCashOrchardStorage::UpdateNotes) - .WithArgs(account_id.Clone(), std::vector({note}), - std::vector(), 50000, "hash50000") + OrchardBlockScanner::Result result = + OrchardBlockScanner::CreateResultForTesting( + OrchardTreeState(), std::vector()); + result.discovered_notes = std::vector({note}); + + sync_state() + .AsyncCall(&ZCashOrchardSyncState::UpdateNotes) + .WithArgs(account_id.Clone(), std::move(result), 50000, "hash50000") .Then(std::move(update_notes_callback)); task_environment_.RunUntilIdle(); diff --git a/components/brave_wallet/common/zcash_utils.cc b/components/brave_wallet/common/zcash_utils.cc index 997ec72b76d9..cfd33d403319 100644 --- a/components/brave_wallet/common/zcash_utils.cc +++ b/components/brave_wallet/common/zcash_utils.cc @@ -108,6 +108,108 @@ std::vector GetNetworkPrefix(bool is_testnet) { } // namespace +// static + +// static +base::Value::Dict OrchardNote::ToValue() const { + base::Value::Dict dict; + + dict.Set("addr", base::HexEncode(addr.data(), addr.size())); + dict.Set("block_id", base::NumberToString(block_id)); + dict.Set("nullifier", base::HexEncode(nullifier.data(), nullifier.size())); + dict.Set("amount", base::NumberToString(amount)); + dict.Set("orchard_commitment_tree_position", + base::NumberToString(orchard_commitment_tree_position)); + dict.Set("rho", base::HexEncode(rho.data(), rho.size())); + dict.Set("seed", base::HexEncode(seed.data(), seed.size())); + + return dict; +} + +// static +std::optional OrchardNote::FromValue( + const base::Value::Dict& value) { + OrchardNote result; + if (!ReadHexByteArrayTo(value, "addr", result.addr)) { + return std::nullopt; + } + + if (!ReadUint32StringTo(value, "block_id", result.block_id)) { + return std::nullopt; + } + + if (!ReadHexByteArrayTo(value, "nullifier", + result.nullifier)) { + return std::nullopt; + } + + if (!ReadUint32StringTo(value, "amount", result.amount)) { + return std::nullopt; + } + + if (!ReadUint32StringTo(value, "orchard_commitment_tree_position", + result.orchard_commitment_tree_position)) { + return std::nullopt; + } + + if (!ReadHexByteArrayTo(value, "rho", result.rho)) { + return std::nullopt; + } + + if (!ReadHexByteArrayTo(value, "seed", result.seed)) { + return std::nullopt; + } + + return result; +} + +// static +base::Value::Dict OrchardInput::ToValue() const { + base::Value::Dict dict; + + // Do not serialize witness ATM since it is calculated before post + dict.Set("note", note.ToValue()); + + return dict; +} + +// static +std::optional OrchardInput::FromValue( + const base::Value::Dict& value) { + OrchardInput result; + + auto* note_dict = value.FindDict("note"); + if (!note_dict) { + return std::nullopt; + } + auto note = OrchardNote::FromValue(*note_dict); + if (!note) { + return std::nullopt; + } + + result.note = *note; + + return result; +} + +OrchardTreeState::OrchardTreeState() {} +OrchardTreeState::~OrchardTreeState() {} +OrchardTreeState::OrchardTreeState(const OrchardTreeState&) = default; + +OrchardNoteWitness::OrchardNoteWitness() = default; +OrchardNoteWitness::~OrchardNoteWitness() = default; +OrchardNoteWitness::OrchardNoteWitness(const OrchardNoteWitness& other) = + default; + +OrchardInput::OrchardInput() = default; +OrchardInput::~OrchardInput() = default; +OrchardInput::OrchardInput(const OrchardInput& other) = default; + +OrchardSpendsBundle::OrchardSpendsBundle() = default; +OrchardSpendsBundle::~OrchardSpendsBundle() = default; +OrchardSpendsBundle::OrchardSpendsBundle(const OrchardSpendsBundle& other) = + default; + DecodedZCashAddress::DecodedZCashAddress() = default; DecodedZCashAddress::~DecodedZCashAddress() = default; DecodedZCashAddress::DecodedZCashAddress(const DecodedZCashAddress& other) = @@ -154,6 +256,52 @@ std::optional OrchardOutput::FromValue( return result; } +OrchardCheckpoint::OrchardCheckpoint() {} +OrchardCheckpoint::OrchardCheckpoint(CheckpointTreeState tree_state_position, + std::vector marks_removed) + : tree_state_position(tree_state_position), + marks_removed(std::move(marks_removed)) {} +OrchardCheckpoint::~OrchardCheckpoint() {} +OrchardCheckpoint::OrchardCheckpoint(const OrchardCheckpoint& other) = default; +OrchardCheckpoint& OrchardCheckpoint::operator=( + const OrchardCheckpoint& other) = default; +OrchardCheckpoint::OrchardCheckpoint(OrchardCheckpoint&& other) = default; +OrchardCheckpoint& OrchardCheckpoint::operator=(OrchardCheckpoint&& other) = + default; + +OrchardCheckpointBundle::OrchardCheckpointBundle(uint32_t checkpoint_id, + OrchardCheckpoint checkpoint) + : checkpoint_id(checkpoint_id), checkpoint(std::move(checkpoint)) {} +OrchardCheckpointBundle::~OrchardCheckpointBundle() {} +OrchardCheckpointBundle::OrchardCheckpointBundle( + const OrchardCheckpointBundle& other) = default; +OrchardCheckpointBundle& OrchardCheckpointBundle::operator=( + const OrchardCheckpointBundle& other) = default; +OrchardCheckpointBundle::OrchardCheckpointBundle( + OrchardCheckpointBundle&& other) = default; +OrchardCheckpointBundle& OrchardCheckpointBundle::operator=( + OrchardCheckpointBundle&& other) = default; + +OrchardShard::OrchardShard() {} +OrchardShard::OrchardShard(OrchardShardAddress address, + std::optional root_hash, + std::vector shard_data) + : address(std::move(address)), + root_hash(std::move(root_hash)), + shard_data(std::move(shard_data)) {} +OrchardShard::~OrchardShard() = default; +OrchardShard::OrchardShard(const OrchardShard& other) = default; +OrchardShard& OrchardShard::operator=(const OrchardShard& other) = default; +OrchardShard::OrchardShard(OrchardShard&& other) = default; +OrchardShard& OrchardShard::operator=(OrchardShard&& other) = default; + +OrchardCap::OrchardCap() = default; +OrchardCap::~OrchardCap() = default; +OrchardCap::OrchardCap(const OrchardCap& other) = default; +OrchardCap& OrchardCap::operator=(const OrchardCap& other) = default; +OrchardCap::OrchardCap(OrchardCap&& other) = default; +OrchardCap& OrchardCap::operator=(OrchardCap&& other) = default; + bool OutputZCashAddressSupported(const std::string& address, bool is_testnet) { auto decoded_address = DecodeZCashAddress(address); if (!decoded_address) { diff --git a/components/brave_wallet/common/zcash_utils.h b/components/brave_wallet/common/zcash_utils.h index d10ff3c955c3..a701d6a421e6 100644 --- a/components/brave_wallet/common/zcash_utils.h +++ b/components/brave_wallet/common/zcash_utils.h @@ -13,6 +13,7 @@ #include #include "base/containers/span.h" +#include "base/types/expected.h" #include "brave/components/brave_wallet/common/brave_wallet.mojom.h" namespace brave_wallet { @@ -35,10 +36,24 @@ inline constexpr size_t kOrchardCipherTextSize = 52u; inline constexpr size_t kOrchardMemoSize = 512u; inline constexpr uint64_t kZCashFullAmount = std::numeric_limits::max(); +inline constexpr size_t kOrchardShardTreeHashSize = 32u; +inline constexpr uint8_t kOrchardShardSubtreeHeight = 8; +inline constexpr uint8_t kOrchardShardTreeHeight = 32; +inline constexpr uint8_t kOrchardNoteRhoSize = 32; +inline constexpr uint8_t kOrchardNoteRSeedSize = 32; +inline constexpr uint8_t kOrchardSpendingKeySize = 32; +inline constexpr size_t kOrchardCompleteBlockHashSize = 32u; using OrchardFullViewKey = std::array; using OrchardMemo = std::array; using OrchardAddrRawPart = std::array; +using OrchardRho = std::array; +using OrchardRseed = std::array; +using OrchardMerkleHash = std::array; +using OrchardNullifier = std::array; +using OrchardShardRootHash = std::array; +using OrchardCommitmentValue = std::array; +using OrchardSpendingKey = std::array; // Reduce current scanning position on this value if reorg is found // All Zcash network participants basically assume rollbacks longer than 100 @@ -89,21 +104,212 @@ struct OrchardOutput { }; // Structure describes note nullifier that marks some note as spent -struct OrchardNullifier { +struct OrchardNoteSpend { // Block id where spent nullifier was met uint32_t block_id = 0; std::array nullifier; - bool operator==(const OrchardNullifier& other) const = default; + bool operator==(const OrchardNoteSpend& other) const = default; }; -// Structure describes found spendable note +// Describes spendable note. +// Spendable note contains related position +// in the Orchard commitment tree, amount and data required +// for costructing zk-proof for spending. struct OrchardNote { + OrchardAddrRawPart addr; uint32_t block_id = 0; - std::array nullifier; + OrchardNullifier nullifier; uint32_t amount = 0; + uint32_t orchard_commitment_tree_position = 0; + OrchardRho rho; + OrchardRseed seed; bool operator==(const OrchardNote& other) const = default; + base::Value::Dict ToValue() const; + static std::optional FromValue(const base::Value::Dict& value); +}; + +// Note witness is a Merkle path in the Orchard commitment tree from the +// note to the tree root according some selected anchor(selected right border in +// the commitment tree). +struct OrchardNoteWitness { + OrchardNoteWitness(); + ~OrchardNoteWitness(); + OrchardNoteWitness(const OrchardNoteWitness& other); + + uint32_t position; + std::vector merkle_path; + bool operator==(const OrchardNoteWitness& other) const = default; +}; + +// Data required for constructing note spend. +struct OrchardInput { + OrchardInput(); + ~OrchardInput(); + OrchardInput(const OrchardInput& other); + + OrchardNote note; + std::optional witness; + + base::Value::Dict ToValue() const; + static std::optional FromValue(const base::Value::Dict& value); +}; + +// Set of Orchard inputs along with keys needed for signing. +struct OrchardSpendsBundle { + OrchardSpendsBundle(); + ~OrchardSpendsBundle(); + OrchardSpendsBundle(const OrchardSpendsBundle& other); + + OrchardSpendingKey sk; + OrchardFullViewKey fvk; + std::vector inputs; +}; + +using CheckpointTreeState = std::optional; + +struct OrchardCheckpoint { + OrchardCheckpoint(); + OrchardCheckpoint(CheckpointTreeState, std::vector); + ~OrchardCheckpoint(); + OrchardCheckpoint(const OrchardCheckpoint& other); + OrchardCheckpoint& operator=(const OrchardCheckpoint& other); + OrchardCheckpoint(OrchardCheckpoint&& other); + OrchardCheckpoint& operator=(OrchardCheckpoint&& other); + + bool operator==(const OrchardCheckpoint& other) const = default; + + CheckpointTreeState tree_state_position; + std::vector marks_removed; +}; + +struct OrchardCheckpointBundle { + OrchardCheckpointBundle(uint32_t checkpoint_id, OrchardCheckpoint); + ~OrchardCheckpointBundle(); + OrchardCheckpointBundle(const OrchardCheckpointBundle& other); + OrchardCheckpointBundle& operator=(const OrchardCheckpointBundle& other); + OrchardCheckpointBundle(OrchardCheckpointBundle&& other); + OrchardCheckpointBundle& operator=(OrchardCheckpointBundle&& other); + + bool operator==(const OrchardCheckpointBundle& other) const = default; + + uint32_t checkpoint_id; + OrchardCheckpoint checkpoint; +}; + +// Address of subtree in the shard tree +struct OrchardShardAddress { + uint8_t level = 0; + uint32_t index = 0; + + bool operator==(const OrchardShardAddress& other) const = default; +}; + +// Top part of the shard tree from the root to the shard roots level +// Used for optimization purposes in the shard tree crate. +struct OrchardCap { + OrchardCap(); + ~OrchardCap(); + + OrchardCap(const OrchardCap& other); + OrchardCap& operator=(const OrchardCap& other); + OrchardCap(OrchardCap&& other); + OrchardCap& operator=(OrchardCap&& other); + + std::vector data; +}; + +// Subtree with root selected from the shard roots level. +struct OrchardShard { + OrchardShard(); + OrchardShard(OrchardShardAddress shard_addr, + std::optional shard_hash, + std::vector shard_data); + ~OrchardShard(); + + OrchardShard(const OrchardShard& other); + OrchardShard& operator=(const OrchardShard& other); + OrchardShard(OrchardShard&& other); + OrchardShard& operator=(OrchardShard&& other); + + bool operator==(const OrchardShard& other) const = default; + + // Subtree root address. + OrchardShardAddress address; + // Root hash exists only on completed shards. + std::optional root_hash; + std::vector shard_data; + // Right-most position of the subtree leaf. + size_t subtree_end_height = 0; +}; + +struct OrchardCommitment { + OrchardCommitmentValue cmu; + bool is_marked; + std::optional checkpoint_id; +}; + +// Compact representation of the Merkle tree on some point. +// Since batch inserting may contain gaps between scan ranges we insert +// frontier which allows to calculate node hashes and witnesses(merkle path from +// leaf to the tree root) even when previous scan ranges are not completed. +struct OrchardTreeState { + OrchardTreeState(); + ~OrchardTreeState(); + OrchardTreeState(const OrchardTreeState&); + + // Tree state is linked to the end of some block. + uint32_t block_height = 0u; + // Number of leafs at the position. + uint32_t tree_size = 0u; + // https://docs.aztec.network/protocol-specs/l1-smart-contracts/frontier + std::vector frontier; +}; + +class OrchardShardTreeDelegate { + public: + enum Error { kStorageError = 0, kConsistensyError = 1 }; + + virtual ~OrchardShardTreeDelegate() {} + + virtual base::expected, Error> GetCap() const = 0; + virtual base::expected PutCap(OrchardCap shard) = 0; + + virtual base::expected, Error> GetLatestShardIndex() + const = 0; + virtual base::expected PutShard(OrchardShard shard) = 0; + virtual base::expected, Error> GetShard( + OrchardShardAddress address) const = 0; + virtual base::expected, Error> LastShard( + uint8_t shard_height) const = 0; + virtual base::expected Truncate(uint32_t block_height) = 0; + + virtual base::expected TruncateCheckpoints( + uint32_t checkpoint_id) = 0; + virtual base::expected CheckpointCount() const = 0; + virtual base::expected, Error> MinCheckpointId() + const = 0; + virtual base::expected, Error> MaxCheckpointId() + const = 0; + virtual base::expected, Error> GetCheckpointAtDepth( + uint32_t depth) const = 0; + virtual base::expected, Error> + GetCheckpoint(uint32_t checkpoint_id) const = 0; + virtual base::expected, Error> + GetCheckpoints(size_t limit) const = 0; + virtual base::expected RemoveCheckpointAt(uint32_t depth) = 0; + virtual base::expected RemoveCheckpoint( + uint32_t checkpoint_id) = 0; + virtual base::expected AddCheckpoint( + uint32_t id, + OrchardCheckpoint checkpoint) = 0; + virtual base::expected UpdateCheckpoint( + uint32_t id, + OrchardCheckpoint checkpoint) = 0; + + virtual base::expected, Error> GetShardRoots( + uint8_t shard_level) const = 0; }; bool OutputZCashAddressSupported(const std::string& address, bool is_testnet); diff --git a/components/services/brave_wallet/public/mojom/zcash_decoder.mojom b/components/services/brave_wallet/public/mojom/zcash_decoder.mojom index df368e94e250..2dfa39ef247b 100644 --- a/components/services/brave_wallet/public/mojom/zcash_decoder.mojom +++ b/components/services/brave_wallet/public/mojom/zcash_decoder.mojom @@ -71,6 +71,12 @@ struct CompactBlock { ChainMetadata chain_metadata; }; +struct SubtreeRoot { + array root_hash; + array complete_block_hash; + uint32 complete_block_height; +}; + interface ZCashDecoder { ParseBlockID(string data) => (BlockID? value); ParseGetAddressUtxos(string data) => (GetAddressUtxosResponse? value);