From b2bc8a99659ebf3e040eed81b46c4eb1ba8f8052 Mon Sep 17 00:00:00 2001 From: Dmitry S <11892559+swift1337@users.noreply.github.com> Date: Mon, 22 Jul 2024 20:04:12 +0200 Subject: [PATCH] feat(zetaclient)!: support for runtime chain (de)provisioning (#2497) * Simplify orchestrator construction * Implement signers provision / deprovision based on AppContext state * Add db package * Fix observers so they can use *db.DB * Move observer map creation to orchestrator * Move observer.Start() to orchestrator. Shutdown zetaclient if not an observer * Implement BTC & EVM RPCs as httptest wrapper * Implement observer map sync based on chainParams * Implement observer & signer sync worker * Respect chainParam.IsSupported * Update readme * Fix conflicts * Fix SOL observer * Address PR comments [1] * Address PR comments [2] * Address PR comments [3] * Update orchestrator sync cadence * Address PR comments [4] --- changelog.md | 1 + cmd/zetaclientd/start.go | 83 +-- cmd/zetaclientd/utils.go | 196 +------ pkg/ptr/ptr.go | 17 + zetaclient/chains/base/observer.go | 112 ++-- zetaclient/chains/base/observer_test.go | 120 +--- .../chains/bitcoin/observer/observer.go | 62 +- .../chains/bitcoin/observer/observer_test.go | 160 ++---- .../chains/bitcoin/observer/outbound_test.go | 6 +- .../chains/bitcoin/rpc/rpc_live_test.go | 8 +- zetaclient/chains/bitcoin/signer/signer.go | 3 +- .../chains/evm/observer/inbound_test.go | 36 +- zetaclient/chains/evm/observer/observer.go | 47 +- .../chains/evm/observer/observer_test.go | 104 ++-- .../chains/evm/observer/outbound_test.go | 14 +- zetaclient/chains/evm/signer/signer_test.go | 7 +- zetaclient/chains/solana/observer/db.go | 30 - .../chains/solana/observer/inbound_test.go | 26 +- zetaclient/chains/solana/observer/observer.go | 31 +- .../observer/{db_test.go => observer_test.go} | 40 +- zetaclient/context/app.go | 80 ++- zetaclient/context/app_test.go | 52 +- zetaclient/db/db.go | 112 ++++ zetaclient/db/db_test.go | 96 ++++ zetaclient/orchestrator/bootstap_test.go | 535 ++++++++++++++++++ zetaclient/orchestrator/bootstrap.go | 407 +++++++++++++ zetaclient/orchestrator/mapping.go | 53 ++ zetaclient/orchestrator/orchestrator.go | 324 +++++++---- zetaclient/orchestrator/orchestrator_test.go | 16 +- zetaclient/testutils/constant.go | 3 - zetaclient/testutils/testrpc/rpc.go | 87 +++ zetaclient/testutils/testrpc/rpc_btc.go | 52 ++ zetaclient/testutils/testrpc/rpc_evm.go | 29 + zetaclient/testutils/testrpc/rpc_solana.go | 21 + 34 files changed, 2036 insertions(+), 934 deletions(-) create mode 100644 pkg/ptr/ptr.go delete mode 100644 zetaclient/chains/solana/observer/db.go rename zetaclient/chains/solana/observer/{db_test.go => observer_test.go} (64%) create mode 100644 zetaclient/db/db.go create mode 100644 zetaclient/db/db_test.go create mode 100644 zetaclient/orchestrator/bootstap_test.go create mode 100644 zetaclient/orchestrator/bootstrap.go create mode 100644 zetaclient/orchestrator/mapping.go create mode 100644 zetaclient/testutils/testrpc/rpc.go create mode 100644 zetaclient/testutils/testrpc/rpc_btc.go create mode 100644 zetaclient/testutils/testrpc/rpc_evm.go create mode 100644 zetaclient/testutils/testrpc/rpc_solana.go diff --git a/changelog.md b/changelog.md index fb481eae11..48fa09bb9a 100644 --- a/changelog.md +++ b/changelog.md @@ -32,6 +32,7 @@ * [2372](https://github.com/zeta-chain/node/pull/2372) - add queries for tss fund migration info * [2416](https://github.com/zeta-chain/node/pull/2416) - add Solana chain information * [2465](https://github.com/zeta-chain/node/pull/2465) - add Solana inbound SOL token observation +* [2497](https://github.com/zeta-chain/node/pull/2416) - support for runtime chain (de)provisioning * [2518](https://github.com/zeta-chain/node/pull/2518) - add support for Solana address in zetacore ### Refactor diff --git a/cmd/zetaclientd/start.go b/cmd/zetaclientd/start.go index 37251c04af..bc0b91ec2a 100644 --- a/cmd/zetaclientd/start.go +++ b/cmd/zetaclientd/start.go @@ -30,6 +30,7 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/metrics" "github.com/zeta-chain/zetacore/zetaclient/orchestrator" mc "github.com/zeta-chain/zetacore/zetaclient/tss" + "github.com/zeta-chain/zetacore/zetaclient/zetacore" ) var StartCmd = &cobra.Command{ @@ -74,7 +75,7 @@ func start(_ *cobra.Command, _ []string) error { } masterLogger := logger.Std - startLogger := masterLogger.With().Str("module", "startup").Logger() + startLogger := logger.Std.With().Str("module", "startup").Logger() appContext := zctx.New(cfg, masterLogger) ctx := zctx.WithAppContext(context.Background(), appContext) @@ -269,23 +270,22 @@ func start(_ *cobra.Command, _ []string) error { startLogger.Error().Msgf("No chains enabled in updated config %s ", cfg.String()) } - observerList, err := zetacoreClient.GetObserverList(ctx) - if err != nil { - startLogger.Error().Err(err).Msg("GetObserverList error") + isObserver, err := isObserverNode(ctx, zetacoreClient) + switch { + case err != nil: + startLogger.Error().Msgf("Unable to determine if node is an observer") return err - } - isNodeActive := false - for _, observer := range observerList { - if observer == zetacoreClient.GetKeys().GetOperatorAddress().String() { - isNodeActive = true - break - } + case !isObserver: + addr := zetacoreClient.GetKeys().GetOperatorAddress().String() + startLogger.Info().Str("operator_address", addr).Msg("This node is not an observer. Exit 0") + return nil } - // CreateSignerMap: This creates a map of all signers for each chain . Each signer is responsible for signing transactions for a particular chain - signerMap, err := CreateSignerMap(ctx, appContext, tss, logger, telemetryServer) + // CreateSignerMap: This creates a map of all signers for each chain. + // Each signer is responsible for signing transactions for a particular chain + signerMap, err := orchestrator.CreateSignerMap(ctx, tss, logger, telemetryServer) if err != nil { - log.Error().Err(err).Msg("CreateSignerMap") + log.Error().Err(err).Msg("Unable to create signer map") return err } @@ -296,35 +296,34 @@ func start(_ *cobra.Command, _ []string) error { } dbpath := filepath.Join(userDir, ".zetaclient/chainobserver") - // Creates a map of all chain observers for each chain. Each chain observer is responsible for observing events on the chain and processing them. - observerMap, err := CreateChainObserverMap(ctx, appContext, zetacoreClient, tss, dbpath, logger, telemetryServer) + // Creates a map of all chain observers for each chain. + // Each chain observer is responsible for observing events on the chain and processing them. + observerMap, err := orchestrator.CreateChainObserverMap(ctx, zetacoreClient, tss, dbpath, logger, telemetryServer) if err != nil { startLogger.Err(err).Msg("CreateChainObserverMap") return err } - if !isNodeActive { - startLogger.Error(). - Msgf("Node %s is not an active observer external chain observers will not be started", zetacoreClient.GetKeys().GetOperatorAddress().String()) - } else { - startLogger.Debug().Msgf("Node %s is an active observer starting external chain observers", zetacoreClient.GetKeys().GetOperatorAddress().String()) - for _, observer := range observerMap { - observer.Start(ctx) - } - } - - // Orchestrator wraps the zetacore client and adds the observers and signer maps to it . This is the high level object used for CCTX interactions - orchestrator := orchestrator.NewOrchestrator( + // Orchestrator wraps the zetacore client and adds the observers and signer maps to it. + // This is the high level object used for CCTX interactions + maestro, err := orchestrator.New( ctx, zetacoreClient, signerMap, observerMap, - masterLogger, + tss, + dbpath, + logger, telemetryServer, ) - err = orchestrator.MonitorCore(ctx) if err != nil { - startLogger.Error().Err(err).Msg("Orchestrator failed to start") + startLogger.Error().Err(err).Msg("Unable to create orchestrator") + return err + } + + // Start orchestrator with all observers and signers + if err := maestro.Start(ctx); err != nil { + startLogger.Error().Err(err).Msg("Unable to start orchestrator") return err } @@ -348,10 +347,6 @@ func start(_ *cobra.Command, _ []string) error { sig := <-ch startLogger.Info().Msgf("stop signal received: %s", sig) - // stop chain observers - for _, observer := range observerMap { - observer.Stop() - } zetacoreClient.Stop() return nil @@ -415,3 +410,21 @@ func promptPasswords() (string, string, error) { return hotKeyPass, TSSKeyPass, err } + +// isObserverNode checks whether THIS node is an observer node. +func isObserverNode(ctx context.Context, client *zetacore.Client) (bool, error) { + observers, err := client.GetObserverList(ctx) + if err != nil { + return false, errors.Wrap(err, "unable to get observers list") + } + + operatorAddress := client.GetKeys().GetOperatorAddress().String() + + for _, observer := range observers { + if observer == operatorAddress { + return true, nil + } + } + + return false, nil +} diff --git a/cmd/zetaclientd/utils.go b/cmd/zetaclientd/utils.go index a7799eadd4..d2b5801bef 100644 --- a/cmd/zetaclientd/utils.go +++ b/cmd/zetaclientd/utils.go @@ -1,28 +1,12 @@ package main import ( - gocontext "context" - "fmt" - sdk "github.com/cosmos/cosmos-sdk/types" - ethcommon "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethclient" - solrpc "github.com/gagliardetto/solana-go/rpc" "github.com/rs/zerolog" "github.com/zeta-chain/zetacore/zetaclient/authz" - "github.com/zeta-chain/zetacore/zetaclient/chains/base" - btcobserver "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/observer" - btcrpc "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/rpc" - btcsigner "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/signer" - evmobserver "github.com/zeta-chain/zetacore/zetaclient/chains/evm/observer" - evmsigner "github.com/zeta-chain/zetacore/zetaclient/chains/evm/signer" - "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" - solanaobserver "github.com/zeta-chain/zetacore/zetaclient/chains/solana/observer" "github.com/zeta-chain/zetacore/zetaclient/config" - "github.com/zeta-chain/zetacore/zetaclient/context" "github.com/zeta-chain/zetacore/zetaclient/keys" - "github.com/zeta-chain/zetacore/zetaclient/metrics" "github.com/zeta-chain/zetacore/zetaclient/zetacore" ) @@ -58,182 +42,4 @@ func CreateZetacoreClient(cfg config.Config, hotkeyPassword string, logger zerol return client, nil } -// CreateSignerMap creates a map of ChainSigners for all chains in the config -func CreateSignerMap( - ctx gocontext.Context, - appContext *context.AppContext, - tss interfaces.TSSSigner, - logger base.Logger, - ts *metrics.TelemetryServer, -) (map[int64]interfaces.ChainSigner, error) { - signerMap := make(map[int64]interfaces.ChainSigner) - - // EVM signers - for _, evmConfig := range appContext.Config().GetAllEVMConfigs() { - if evmConfig.Chain.IsZetaChain() { - continue - } - evmChainParams, found := appContext.GetEVMChainParams(evmConfig.Chain.ChainId) - if !found { - logger.Std.Error().Msgf("ChainParam not found for chain %s", evmConfig.Chain.String()) - continue - } - - chainName := evmConfig.Chain.ChainName.String() - mpiAddress := ethcommon.HexToAddress(evmChainParams.ConnectorContractAddress) - erc20CustodyAddress := ethcommon.HexToAddress(evmChainParams.Erc20CustodyContractAddress) - - signer, err := evmsigner.NewSigner( - ctx, - evmConfig.Chain, - tss, - ts, - logger, - evmConfig.Endpoint, - config.GetConnectorABI(), - config.GetERC20CustodyABI(), - mpiAddress, - erc20CustodyAddress, - ) - if err != nil { - logger.Std.Error().Err(err).Msgf("NewSigner error for EVM chain %q", chainName) - continue - } - - signerMap[evmConfig.Chain.ChainId] = signer - logger.Std.Info().Msgf("NewSigner succeeded for EVM chain %q", chainName) - } - - // BTC signer - btcChain, btcConfig, btcEnabled := appContext.GetBTCChainAndConfig() - if btcEnabled { - chainName := btcChain.ChainName.String() - - signer, err := btcsigner.NewSigner(btcChain, tss, ts, logger, btcConfig) - if err != nil { - logger.Std.Error().Err(err).Msgf("NewSigner error for BTC chain %q", chainName) - } else { - signerMap[btcChain.ChainId] = signer - logger.Std.Info().Msgf("NewSigner succeeded for BTC chain %q", chainName) - } - } - - return signerMap, nil -} - -// CreateChainObserverMap creates a map of ChainObservers for all chains in the config -func CreateChainObserverMap( - ctx gocontext.Context, - appContext *context.AppContext, - zetacoreClient *zetacore.Client, - tss interfaces.TSSSigner, - dbpath string, - logger base.Logger, - ts *metrics.TelemetryServer, -) (map[int64]interfaces.ChainObserver, error) { - observerMap := make(map[int64]interfaces.ChainObserver) - // EVM observers - for _, evmConfig := range appContext.Config().GetAllEVMConfigs() { - if evmConfig.Chain.IsZetaChain() { - continue - } - chainParams, found := appContext.GetEVMChainParams(evmConfig.Chain.ChainId) - if !found { - logger.Std.Error().Msgf("ChainParam not found for chain %s", evmConfig.Chain.String()) - continue - } - - // create EVM client - evmClient, err := ethclient.Dial(evmConfig.Endpoint) - if err != nil { - logger.Std.Error().Err(err).Msgf("error dailing endpoint %q", evmConfig.Endpoint) - continue - } - - // create EVM chain observer - observer, err := evmobserver.NewObserver( - ctx, - evmConfig, - evmClient, - *chainParams, - zetacoreClient, - tss, - dbpath, - logger, - ts, - ) - if err != nil { - logger.Std.Error().Err(err).Msgf("NewObserver error for evm chain %s", evmConfig.Chain.String()) - continue - } - observerMap[evmConfig.Chain.ChainId] = observer - } - - // BTC observer - _, btcChainParams, found := appContext.GetBTCChainParams() - if !found { - return nil, fmt.Errorf("bitcoin chains params not found") - } - - // create BTC chain observer - btcChain, btcConfig, enabled := appContext.GetBTCChainAndConfig() - if enabled { - btcClient, err := btcrpc.NewRPCClient(btcConfig) - if err != nil { - logger.Std.Error().Err(err).Msgf("error creating rpc client for bitcoin chain %s", btcChain.String()) - } else { - // create BTC chain observer - observer, err := btcobserver.NewObserver( - btcChain, - btcClient, - *btcChainParams, - zetacoreClient, - tss, - dbpath, - logger, - ts, - ) - if err != nil { - logger.Std.Error().Err(err).Msgf("NewObserver error for bitcoin chain %s", btcChain.String()) - } else { - observerMap[btcChain.ChainId] = observer - } - } - } - - // Solana chain params - _, solChainParams, found := appContext.GetSolanaChainParams() - if !found { - logger.Std.Error().Msg("solana chain params not found") - return observerMap, nil - } - - // create Solana chain observer - solChain, solConfig, enabled := appContext.GetSolanaChainAndConfig() - if enabled { - rpcClient := solrpc.New(solConfig.Endpoint) - if rpcClient == nil { - // should never happen - logger.Std.Error().Msg("solana create Solana client error") - return observerMap, nil - } - - observer, err := solanaobserver.NewObserver( - solChain, - rpcClient, - *solChainParams, - zetacoreClient, - tss, - dbpath, - logger, - ts, - ) - if err != nil { - logger.Std.Error().Err(err).Msg("NewObserver error for solana chain") - } else { - observerMap[solChainParams.ChainId] = observer - } - } - - return observerMap, nil -} +// TODO diff --git a/pkg/ptr/ptr.go b/pkg/ptr/ptr.go new file mode 100644 index 0000000000..baab694c81 --- /dev/null +++ b/pkg/ptr/ptr.go @@ -0,0 +1,17 @@ +// Package ptr provides helper functions for working with pointers. +package ptr + +// Ptr returns a pointer to the value passed in. +func Ptr[T any](value T) *T { + return &value +} + +// Deref returns the value of the pointer passed in, or the zero value of the type if the pointer is nil. +func Deref[T any](value *T) T { + var out T + if value != nil { + out = *value + } + + return out +} diff --git a/zetaclient/chains/base/observer.go b/zetaclient/chains/base/observer.go index b3eb554a96..90a190a020 100644 --- a/zetaclient/chains/base/observer.go +++ b/zetaclient/chains/base/observer.go @@ -5,21 +5,18 @@ import ( "fmt" "os" "strconv" - "strings" "sync" "sync/atomic" lru "github.com/hashicorp/golang-lru" "github.com/pkg/errors" "github.com/rs/zerolog" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" "github.com/zeta-chain/zetacore/pkg/chains" crosschaintypes "github.com/zeta-chain/zetacore/x/crosschain/types" observertypes "github.com/zeta-chain/zetacore/x/observer/types" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/metrics" clienttypes "github.com/zeta-chain/zetacore/zetaclient/types" "github.com/zeta-chain/zetacore/zetaclient/zetacore" @@ -69,7 +66,7 @@ type Observer struct { headerCache *lru.Cache // db is the database to persist data - db *gorm.DB + db *db.DB // ts is the telemetry server for metrics ts *metrics.TelemetryServer @@ -79,7 +76,8 @@ type Observer struct { // mu protects fields from concurrent access // Note: base observer simply provides the mutex. It's the sub-struct's responsibility to use it to be thread-safe - mu *sync.Mutex + mu *sync.Mutex + started bool // stop is the channel to signal the observer to stop stop chan struct{} @@ -94,6 +92,7 @@ func NewObserver( blockCacheSize int, headerCacheSize int, ts *metrics.TelemetryServer, + database *db.DB, logger Logger, ) (*Observer, error) { ob := Observer{ @@ -105,6 +104,7 @@ func NewObserver( lastBlockScanned: 0, lastTxScanned: "", ts: ts, + db: database, mu: &sync.Mutex{}, stop: make(chan struct{}), } @@ -128,18 +128,41 @@ func NewObserver( return &ob, nil } +// Start starts the observer. Returns true if the observer was already started (noop). +func (ob *Observer) Start() bool { + ob.mu.Lock() + defer ob.Mu().Unlock() + + // noop + if ob.started { + return true + } + + ob.started = true + + return false +} + // Stop notifies all goroutines to stop and closes the database. func (ob *Observer) Stop() { - ob.logger.Chain.Info().Msgf("observer is stopping for chain %d", ob.Chain().ChainId) + ob.mu.Lock() + defer ob.mu.Unlock() + + if !ob.started { + ob.logger.Chain.Info().Msgf("Observer already stopped for chain %d", ob.Chain().ChainId) + return + } + + ob.logger.Chain.Info().Msgf("Stopping observer for chain %d", ob.Chain().ChainId) + close(ob.stop) + ob.started = false // close database - if ob.db != nil { - err := ob.CloseDB() - if err != nil { - ob.Logger().Chain.Error().Err(err).Msgf("CloseDB failed for chain %d", ob.Chain().ChainId) - } + if err := ob.db.Close(); err != nil { + ob.Logger().Chain.Error().Err(err).Msgf("unable to close db for chain %d", ob.Chain().ChainId) } + ob.Logger().Chain.Info().Msgf("observer stopped for chain %d", ob.Chain().ChainId) } @@ -245,7 +268,7 @@ func (ob *Observer) WithHeaderCache(cache *lru.Cache) *Observer { } // DB returns the database for the observer. -func (ob *Observer) DB() *gorm.DB { +func (ob *Observer) DB() *db.DB { return ob.db } @@ -289,59 +312,6 @@ func (ob *Observer) StopChannel() chan struct{} { return ob.stop } -// OpenDB open sql database in the given path. -func (ob *Observer) OpenDB(dbPath string, dbName string) error { - // create db path if not exist - if _, err := os.Stat(dbPath); os.IsNotExist(err) { - err := os.MkdirAll(dbPath, 0o750) - if err != nil { - return errors.Wrapf(err, "error creating db path: %s", dbPath) - } - } - - // use custom dbName or chain name if not provided - if dbName == "" { - dbName = ob.chain.ChainName.String() - } - path := fmt.Sprintf("%s/%s", dbPath, dbName) - - // use memory db if specified - if strings.Contains(dbPath, ":memory:") { - path = dbPath - } - - // open db - db, err := gorm.Open(sqlite.Open(path), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) - if err != nil { - return errors.Wrap(err, "error opening db") - } - - // migrate db - err = db.AutoMigrate( - &clienttypes.LastBlockSQLType{}, - &clienttypes.LastTransactionSQLType{}, - ) - if err != nil { - return errors.Wrap(err, "error migrating db") - } - ob.db = db - - return nil -} - -// CloseDB close the database. -func (ob *Observer) CloseDB() error { - dbInst, err := ob.db.DB() - if err != nil { - return fmt.Errorf("error getting database instance: %w", err) - } - err = dbInst.Close() - if err != nil { - return fmt.Errorf("error closing database: %w", err) - } - return nil -} - // LoadLastBlockScanned loads last scanned block from environment variable or from database. // The last scanned block is the height from which the observer should continue scanning. func (ob *Observer) LoadLastBlockScanned(logger zerolog.Logger) error { @@ -358,7 +328,7 @@ func (ob *Observer) LoadLastBlockScanned(logger zerolog.Logger) error { } blockNumber, err := strconv.ParseUint(scanFromBlock, 10, 64) if err != nil { - return err + return errors.Wrapf(err, "unable to parse block number from ENV %s=%s", envvar, scanFromBlock) } ob.WithLastBlockScanned(blockNumber) return nil @@ -383,13 +353,13 @@ func (ob *Observer) SaveLastBlockScanned(blockNumber uint64) error { // WriteLastBlockScannedToDB saves the last scanned block to the database. func (ob *Observer) WriteLastBlockScannedToDB(lastScannedBlock uint64) error { - return ob.db.Save(clienttypes.ToLastBlockSQLType(lastScannedBlock)).Error + return ob.db.Client().Save(clienttypes.ToLastBlockSQLType(lastScannedBlock)).Error } // ReadLastBlockScannedFromDB reads the last scanned block from the database. func (ob *Observer) ReadLastBlockScannedFromDB() (uint64, error) { var lastBlock clienttypes.LastBlockSQLType - if err := ob.db.First(&lastBlock, clienttypes.LastBlockNumID).Error; err != nil { + if err := ob.db.Client().First(&lastBlock, clienttypes.LastBlockNumID).Error; err != nil { // record not found return 0, err } @@ -433,13 +403,13 @@ func (ob *Observer) SaveLastTxScanned(txHash string, slot uint64) error { // WriteLastTxScannedToDB saves the last scanned tx hash to the database. func (ob *Observer) WriteLastTxScannedToDB(txHash string) error { - return ob.db.Save(clienttypes.ToLastTxHashSQLType(txHash)).Error + return ob.db.Client().Save(clienttypes.ToLastTxHashSQLType(txHash)).Error } // ReadLastTxScannedFromDB reads the last scanned tx hash from the database. func (ob *Observer) ReadLastTxScannedFromDB() (string, error) { var lastTx clienttypes.LastTransactionSQLType - if err := ob.db.First(&lastTx, clienttypes.LastTxHashID).Error; err != nil { + if err := ob.db.Client().First(&lastTx, clienttypes.LastTxHashID).Error; err != nil { // record not found return "", err } diff --git a/zetaclient/chains/base/observer_test.go b/zetaclient/chains/base/observer_test.go index cd3ce26374..b40802c0a7 100644 --- a/zetaclient/chains/base/observer_test.go +++ b/zetaclient/chains/base/observer_test.go @@ -9,8 +9,6 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" - "github.com/zeta-chain/zetacore/zetaclient/testutils" - "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/coin" "github.com/zeta-chain/zetacore/testutil/sample" @@ -19,6 +17,7 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" "github.com/zeta-chain/zetacore/zetaclient/config" zctx "github.com/zeta-chain/zetacore/zetaclient/context" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/metrics" "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" ) @@ -30,6 +29,8 @@ func createObserver(t *testing.T, chain chains.Chain) *base.Observer { zetacoreClient := mocks.NewZetacoreClient(t) tss := mocks.NewTSSMainnet() + database := createDatabase(t) + // create observer logger := base.DefaultLogger() ob, err := base.NewObserver( @@ -40,6 +41,7 @@ func createObserver(t *testing.T, chain chains.Chain) *base.Observer { base.DefaultBlockCacheSize, base.DefaultHeaderCacheSize, nil, + database, logger, ) require.NoError(t, err) @@ -57,6 +59,8 @@ func TestNewObserver(t *testing.T) { blockCacheSize := base.DefaultBlockCacheSize headersCacheSize := base.DefaultHeaderCacheSize + database := createDatabase(t) + // test cases tests := []struct { name string @@ -118,6 +122,7 @@ func TestNewObserver(t *testing.T) { tt.blockCacheSize, tt.headerCacheSize, nil, + database, base.DefaultLogger(), ) if tt.fail { @@ -136,7 +141,6 @@ func TestStop(t *testing.T) { t.Run("should be able to stop observer", func(t *testing.T) { // create observer and initialize db ob := createObserver(t, chains.Ethereum) - ob.OpenDB(sample.CreateTempDir(t), "") // stop observer ob.Stop() @@ -145,6 +149,7 @@ func TestStop(t *testing.T) { func TestObserverGetterAndSetter(t *testing.T) { chain := chains.Ethereum + t.Run("should be able to update chain", func(t *testing.T) { ob := createObserver(t, chain) @@ -221,15 +226,6 @@ func TestObserverGetterAndSetter(t *testing.T) { ob = ob.WithHeaderCache(newHeadersCache) require.Equal(t, newHeadersCache, ob.HeaderCache()) }) - t.Run("should be able to get database", func(t *testing.T) { - // create observer and open db - dbPath := sample.CreateTempDir(t) - ob := createObserver(t, chain) - ob.OpenDB(dbPath, "") - - db := ob.DB() - require.NotNil(t, db) - }) t.Run("should be able to update telemetry server", func(t *testing.T) { ob := createObserver(t, chain) @@ -252,47 +248,17 @@ func TestObserverGetterAndSetter(t *testing.T) { }) } -func TestOpenCloseDB(t *testing.T) { - dbPath := sample.CreateTempDir(t) - ob := createObserver(t, chains.Ethereum) - - t.Run("should be able to open/close db", func(t *testing.T) { - // open db - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) - - // close db - err = ob.CloseDB() - require.NoError(t, err) - }) - t.Run("should use memory db if specified", func(t *testing.T) { - // open db with memory - err := ob.OpenDB(testutils.SQLiteMemory, "") - require.NoError(t, err) - - // close db - err = ob.CloseDB() - require.NoError(t, err) - }) - t.Run("should return error on invalid db path", func(t *testing.T) { - err := ob.OpenDB("/invalid/123db", "") - require.ErrorContains(t, err, "error creating db path") - }) -} - func TestLoadLastBlockScanned(t *testing.T) { chain := chains.Ethereum envvar := base.EnvVarLatestBlockByChain(chain) t.Run("should be able to load last block scanned", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // create db and write 100 as last block scanned - ob.WriteLastBlockScannedToDB(100) + err := ob.WriteLastBlockScannedToDB(100) + require.NoError(t, err) // read last block scanned err = ob.LoadLastBlockScanned(log.Logger) @@ -301,22 +267,16 @@ func TestLoadLastBlockScanned(t *testing.T) { }) t.Run("latest block scanned should be 0 if not found in db", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // read last block scanned - err = ob.LoadLastBlockScanned(log.Logger) + err := ob.LoadLastBlockScanned(log.Logger) require.NoError(t, err) require.EqualValues(t, 0, ob.LastBlockScanned()) }) t.Run("should overwrite last block scanned if env var is set", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // create db and write 100 as last block scanned ob.WriteLastBlockScannedToDB(100) @@ -325,16 +285,13 @@ func TestLoadLastBlockScanned(t *testing.T) { os.Setenv(envvar, "101") // read last block scanned - err = ob.LoadLastBlockScanned(log.Logger) + err := ob.LoadLastBlockScanned(log.Logger) require.NoError(t, err) require.EqualValues(t, 101, ob.LastBlockScanned()) }) t.Run("last block scanned should remain 0 if env var is set to latest", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // create db and write 100 as last block scanned ob.WriteLastBlockScannedToDB(100) @@ -343,22 +300,19 @@ func TestLoadLastBlockScanned(t *testing.T) { os.Setenv(envvar, base.EnvVarLatestBlock) // last block scanned should remain 0 - err = ob.LoadLastBlockScanned(log.Logger) + err := ob.LoadLastBlockScanned(log.Logger) require.NoError(t, err) require.EqualValues(t, 0, ob.LastBlockScanned()) }) t.Run("should return error on invalid env var", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // set invalid env var os.Setenv(envvar, "invalid") // read last block scanned - err = ob.LoadLastBlockScanned(log.Logger) + err := ob.LoadLastBlockScanned(log.Logger) require.Error(t, err) }) } @@ -366,13 +320,10 @@ func TestLoadLastBlockScanned(t *testing.T) { func TestSaveLastBlockScanned(t *testing.T) { t.Run("should be able to save last block scanned", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chains.Ethereum) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // save 100 as last block scanned - err = ob.SaveLastBlockScanned(100) + err := ob.SaveLastBlockScanned(100) require.NoError(t, err) // check last block scanned in memory @@ -389,13 +340,10 @@ func TestReadWriteDBLastBlockScanned(t *testing.T) { chain := chains.Ethereum t.Run("should be able to write and read last block scanned to db", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // write last block scanned - err = ob.WriteLastBlockScannedToDB(100) + err := ob.WriteLastBlockScannedToDB(100) require.NoError(t, err) lastBlockScanned, err := ob.ReadLastBlockScannedFromDB() @@ -404,17 +352,13 @@ func TestReadWriteDBLastBlockScanned(t *testing.T) { }) t.Run("should return error when last block scanned not found in db", func(t *testing.T) { // create empty db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) lastScannedBlock, err := ob.ReadLastBlockScannedFromDB() require.Error(t, err) require.Zero(t, lastScannedBlock) }) } - func TestLoadLastTxScanned(t *testing.T) { chain := chains.SolanaDevnet envvar := base.EnvVarLatestTxByChain(chain) @@ -422,37 +366,26 @@ func TestLoadLastTxScanned(t *testing.T) { t.Run("should be able to load last tx scanned", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // create db and write sample hash as last tx scanned ob.WriteLastTxScannedToDB(lastTx) // read last tx scanned ob.LoadLastTxScanned() - require.NoError(t, err) require.EqualValues(t, lastTx, ob.LastTxScanned()) }) t.Run("latest tx scanned should be empty if not found in db", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // read last tx scanned ob.LoadLastTxScanned() - require.NoError(t, err) require.Empty(t, ob.LastTxScanned()) }) t.Run("should overwrite last tx scanned if env var is set", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // create db and write sample hash as last tx scanned ob.WriteLastTxScannedToDB(lastTx) @@ -463,7 +396,6 @@ func TestLoadLastTxScanned(t *testing.T) { // read last block scanned ob.LoadLastTxScanned() - require.NoError(t, err) require.EqualValues(t, otherTx, ob.LastTxScanned()) }) } @@ -472,15 +404,12 @@ func TestSaveLastTxScanned(t *testing.T) { chain := chains.SolanaDevnet t.Run("should be able to save last tx scanned", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // save random tx hash lastSlot := uint64(100) lastTx := "5LuQMorgd11p8GWEw6pmyHCDtA26NUyeNFhLWPNk2oBoM9pkag1LzhwGSRos3j4TJLhKjswFhZkGtvSGdLDkmqsk" - err = ob.SaveLastTxScanned(lastTx, lastSlot) + err := ob.SaveLastTxScanned(lastTx, lastSlot) require.NoError(t, err) // check last tx and slot scanned in memory @@ -498,14 +427,11 @@ func TestReadWriteDBLastTxScanned(t *testing.T) { chain := chains.SolanaDevnet t.Run("should be able to write and read last tx scanned to db", func(t *testing.T) { // create observer and open db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) // write last tx scanned lastTx := "5LuQMorgd11p8GWEw6pmyHCDtA26NUyeNFhLWPNk2oBoM9pkag1LzhwGSRos3j4TJLhKjswFhZkGtvSGdLDkmqsk" - err = ob.WriteLastTxScannedToDB(lastTx) + err := ob.WriteLastTxScannedToDB(lastTx) require.NoError(t, err) lastTxScanned, err := ob.ReadLastTxScannedFromDB() @@ -514,10 +440,7 @@ func TestReadWriteDBLastTxScanned(t *testing.T) { }) t.Run("should return error when last tx scanned not found in db", func(t *testing.T) { // create empty db - dbPath := sample.CreateTempDir(t) ob := createObserver(t, chain) - err := ob.OpenDB(dbPath, "") - require.NoError(t, err) lastTxScanned, err := ob.ReadLastTxScannedFromDB() require.Error(t, err) @@ -542,3 +465,10 @@ func TestPostVoteInbound(t *testing.T) { require.Equal(t, "sampleBallotIndex", ballot) }) } + +func createDatabase(t *testing.T) *db.DB { + sqlDatabase, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + + return sqlDatabase +} diff --git a/zetaclient/chains/bitcoin/observer/observer.go b/zetaclient/chains/bitcoin/observer/observer.go index 2aea6e450f..4b07974052 100644 --- a/zetaclient/chains/bitcoin/observer/observer.go +++ b/zetaclient/chains/bitcoin/observer/observer.go @@ -27,6 +27,7 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin" "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/rpc" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/metrics" clienttypes "github.com/zeta-chain/zetacore/zetaclient/types" ) @@ -116,7 +117,7 @@ func NewObserver( chainParams observertypes.ChainParams, zetacoreClient interfaces.ZetacoreClient, tss interfaces.TSSSigner, - dbpath string, + database *db.DB, logger base.Logger, ts *metrics.TelemetryServer, ) (*Observer, error) { @@ -129,16 +130,17 @@ func NewObserver( btcBlocksPerDay, base.DefaultHeaderCacheSize, ts, + database, logger, ) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "unable to create base observer for chain %d", chain.ChainId) } // get the bitcoin network params netParams, err := chains.BitcoinNetParamsFromChainID(chain.ChainId) if err != nil { - return nil, fmt.Errorf("error getting net params for chain %d: %s", chain.ChainId, err) + return nil, errors.Wrapf(err, "unable to get BTC net params for chain %d", chain.ChainId) } // create bitcoin observer @@ -157,10 +159,14 @@ func NewObserver( }, } - // load btc chain observer DB - err = ob.LoadDB(dbpath) - if err != nil { - return nil, err + // load last scanned block + if err := ob.LoadLastBlockScanned(); err != nil { + return nil, errors.Wrap(err, "unable to load last scanned block") + } + + // load broadcasted transactions + if err := ob.LoadBroadcastedTxMap(); err != nil { + return nil, errors.Wrap(err, "unable to load broadcasted tx map") } return ob, nil @@ -194,6 +200,11 @@ func (ob *Observer) GetChainParams() observertypes.ChainParams { // Start starts the Go routine processes to observe the Bitcoin chain func (ob *Observer) Start(ctx context.Context) { + if noop := ob.Observer.Start(); noop { + ob.Logger().Chain.Info().Msgf("observer is already started for chain %d", ob.Chain().ChainId) + return + } + ob.Logger().Chain.Info().Msgf("observer is starting for chain %d", ob.Chain().ChainId) // watch bitcoin chain for incoming txs and post votes to zetacore @@ -529,7 +540,7 @@ func (ob *Observer) SaveBroadcastedTx(txHash string, nonce uint64) { ob.Mu().Unlock() broadcastEntry := clienttypes.ToOutboundHashSQLType(txHash, outboundID) - if err := ob.DB().Save(&broadcastEntry).Error; err != nil { + if err := ob.DB().Client().Save(&broadcastEntry).Error; err != nil { ob.logger.Outbound.Error(). Err(err). Msgf("SaveBroadcastedTx: error saving broadcasted txHash %s for outbound %s", txHash, outboundID) @@ -570,39 +581,6 @@ func (ob *Observer) GetBlockByNumberCached(blockNumber int64) (*BTCBlockNHeader, return blockNheader, nil } -// LoadDB open sql database and load data into Bitcoin observer -func (ob *Observer) LoadDB(dbPath string) error { - if dbPath == "" { - return errors.New("empty db path") - } - - // open database, the custom dbName is used here for backward compatibility - err := ob.OpenDB(dbPath, "btc_chain_client") - if err != nil { - return errors.Wrapf(err, "error OpenDB for chain %d", ob.Chain().ChainId) - } - - // run auto migration - // transaction result table is used nowhere but we still run migration in case they are needed in future - err = ob.DB().AutoMigrate( - &clienttypes.TransactionResultSQLType{}, - &clienttypes.OutboundHashSQLType{}, - ) - if err != nil { - return errors.Wrapf(err, "error AutoMigrate for chain %d", ob.Chain().ChainId) - } - - // load last scanned block - err = ob.LoadLastBlockScanned() - if err != nil { - return err - } - - // load broadcasted transactions - err = ob.LoadBroadcastedTxMap() - return err -} - // LoadLastBlockScanned loads the last scanned block from the database func (ob *Observer) LoadLastBlockScanned() error { err := ob.Observer.LoadLastBlockScanned(ob.Logger().Chain) @@ -634,7 +612,7 @@ func (ob *Observer) LoadLastBlockScanned() error { // LoadBroadcastedTxMap loads broadcasted transactions from the database func (ob *Observer) LoadBroadcastedTxMap() error { var broadcastedTransactions []clienttypes.OutboundHashSQLType - if err := ob.DB().Find(&broadcastedTransactions).Error; err != nil { + if err := ob.DB().Client().Find(&broadcastedTransactions).Error; err != nil { ob.logger.Chain.Error().Err(err).Msgf("error iterating over db for chain %d", ob.Chain().ChainId) return err } diff --git a/zetaclient/chains/bitcoin/observer/observer_test.go b/zetaclient/chains/bitcoin/observer/observer_test.go index 438324b091..c873dfb8d7 100644 --- a/zetaclient/chains/bitcoin/observer/observer_test.go +++ b/zetaclient/chains/bitcoin/observer/observer_test.go @@ -11,8 +11,7 @@ import ( "github.com/btcsuite/btcd/wire" lru "github.com/hashicorp/golang-lru" "github.com/stretchr/testify/require" - "github.com/zeta-chain/zetacore/zetaclient/testutils" - "gorm.io/driver/sqlite" + "github.com/zeta-chain/zetacore/zetaclient/db" "gorm.io/gorm" "github.com/zeta-chain/zetacore/pkg/chains" @@ -35,10 +34,7 @@ var ( func setupDBTxResults(t *testing.T) (*gorm.DB, map[string]btcjson.GetTransactionResult) { submittedTx := map[string]btcjson.GetTransactionResult{} - db, err := gorm.Open(sqlite.Open(testutils.SQLiteMemory), &gorm.Config{}) - require.NoError(t, err) - - err = db.AutoMigrate(&clienttypes.TransactionResultSQLType{}) + database, err := db.NewFromSqliteInMemory(true) require.NoError(t, err) //Create some Transaction entries in the DB @@ -58,12 +54,12 @@ func setupDBTxResults(t *testing.T) (*gorm.DB, map[string]btcjson.GetTransaction Hex: "", } r, _ := clienttypes.ToTransactionResultSQLType(txResult, strconv.Itoa(i)) - dbc := db.Create(&r) + dbc := database.Client().Create(&r) require.NoError(t, dbc.Error) submittedTx[strconv.Itoa(i)] = txResult } - return db, submittedTx + return database.Client(), submittedTx } // MockBTCObserver creates a mock Bitcoin observer for testing @@ -72,17 +68,13 @@ func MockBTCObserver( chain chains.Chain, params observertypes.ChainParams, btcClient interfaces.BTCRPCClient, - dbpath string, ) *observer.Observer { // use default mock btc client if not provided if btcClient == nil { btcClient = mocks.NewMockBTCRPCClient().WithBlockCount(100) } - // use memory db if dbpath is empty - if dbpath == "" { - dbpath = "file::memory:?cache=shared" - } + database, err := db.NewFromSqliteInMemory(true) // create observer ob, err := observer.NewObserver( @@ -91,7 +83,7 @@ func MockBTCObserver( params, nil, nil, - dbpath, + database, base.Logger{}, nil, ) @@ -107,17 +99,17 @@ func Test_NewObserver(t *testing.T) { // test cases tests := []struct { - name string - chain chains.Chain - btcClient interfaces.BTCRPCClient - chainParams observertypes.ChainParams - coreClient interfaces.ZetacoreClient - tss interfaces.TSSSigner - dbpath string - logger base.Logger - ts *metrics.TelemetryServer - fail bool - message string + name string + chain chains.Chain + btcClient interfaces.BTCRPCClient + chainParams observertypes.ChainParams + coreClient interfaces.ZetacoreClient + tss interfaces.TSSSigner + logger base.Logger + ts *metrics.TelemetryServer + errorMessage string + before func() + after func() }{ { name: "should be able to create observer", @@ -126,42 +118,50 @@ func Test_NewObserver(t *testing.T) { chainParams: params, coreClient: nil, tss: mocks.NewTSSMainnet(), - dbpath: sample.CreateTempDir(t), - logger: base.Logger{}, - ts: nil, - fail: false, }, { - name: "should fail if net params is not found", - chain: chains.Chain{ChainId: 111}, // invalid chain id - btcClient: mocks.NewMockBTCRPCClient().WithBlockCount(100), - chainParams: params, - coreClient: nil, - tss: mocks.NewTSSMainnet(), - dbpath: sample.CreateTempDir(t), - logger: base.Logger{}, - ts: nil, - fail: true, - message: "error getting net params", + name: "should fail if net params is not found", + chain: chains.Chain{ChainId: 111}, // invalid chain id + btcClient: mocks.NewMockBTCRPCClient().WithBlockCount(100), + chainParams: params, + coreClient: nil, + tss: mocks.NewTSSMainnet(), + errorMessage: "unable to get BTC net params for chain", }, { - name: "should fail on invalid dbpath", + name: "should fail if env var us invalid", chain: chain, + btcClient: mocks.NewMockBTCRPCClient().WithBlockCount(100), chainParams: params, coreClient: nil, - btcClient: mocks.NewMockBTCRPCClient().WithBlockCount(100), tss: mocks.NewTSSMainnet(), - dbpath: "/invalid/dbpath", // invalid dbpath - logger: base.Logger{}, - ts: nil, - fail: true, - message: "error creating db path", + before: func() { + envVar := base.EnvVarLatestBlockByChain(chain) + os.Setenv(envVar, "invalid") + }, + after: func() { + envVar := base.EnvVarLatestBlockByChain(chain) + os.Unsetenv(envVar) + }, + errorMessage: "unable to parse block number from ENV", }, } // run tests for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // create db + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + + if tt.before != nil { + tt.before() + } + + if tt.after != nil { + defer tt.after() + } + // create observer ob, err := observer.NewObserver( tt.chain, @@ -169,19 +169,19 @@ func Test_NewObserver(t *testing.T) { tt.chainParams, tt.coreClient, tt.tss, - tt.dbpath, + database, tt.logger, tt.ts, ) - // check result - if tt.fail { - require.ErrorContains(t, err, tt.message) + if tt.errorMessage != "" { + require.ErrorContains(t, err, tt.errorMessage) require.Nil(t, ob) - } else { - require.NoError(t, err) - require.NotNil(t, ob) + return } + + require.NoError(t, err) + require.NotNil(t, ob) }) } } @@ -236,46 +236,6 @@ func Test_BlockCache(t *testing.T) { }) } -func Test_LoadDB(t *testing.T) { - // use Bitcoin mainnet chain for testing - chain := chains.BitcoinMainnet - params := mocks.MockChainParams(chain.ChainId, 10) - - // create mock btc client, tss and test dbpath - btcClient := mocks.NewMockBTCRPCClient().WithBlockCount(100) - tss := mocks.NewTSSMainnet() - - // create observer - dbpath := sample.CreateTempDir(t) - ob, err := observer.NewObserver(chain, btcClient, params, nil, tss, dbpath, base.Logger{}, nil) - require.NoError(t, err) - - t.Run("should load db successfully", func(t *testing.T) { - err := ob.LoadDB(dbpath) - require.NoError(t, err) - require.EqualValues(t, 100, ob.LastBlockScanned()) - }) - t.Run("should fail on invalid dbpath", func(t *testing.T) { - // load db with empty dbpath - err := ob.LoadDB("") - require.ErrorContains(t, err, "empty db path") - - // load db with invalid dbpath - err = ob.LoadDB("/invalid/dbpath") - require.ErrorContains(t, err, "error OpenDB") - }) - t.Run("should fail on invalid env var", func(t *testing.T) { - // set invalid environment variable - envvar := base.EnvVarLatestBlockByChain(chain) - os.Setenv(envvar, "invalid") - defer os.Unsetenv(envvar) - - // load db - err := ob.LoadDB(dbpath) - require.ErrorContains(t, err, "error LoadLastBlockScanned") - }) -} - func Test_LoadLastBlockScanned(t *testing.T) { // use Bitcoin mainnet chain for testing chain := chains.BitcoinMainnet @@ -283,11 +243,10 @@ func Test_LoadLastBlockScanned(t *testing.T) { // create observer using mock btc client btcClient := mocks.NewMockBTCRPCClient().WithBlockCount(200) - dbpath := sample.CreateTempDir(t) t.Run("should load last block scanned", func(t *testing.T) { // create observer and write 199 as last block scanned - ob := MockBTCObserver(t, chain, params, btcClient, dbpath) + ob := MockBTCObserver(t, chain, params, btcClient) ob.WriteLastBlockScannedToDB(199) // load last block scanned @@ -297,7 +256,7 @@ func Test_LoadLastBlockScanned(t *testing.T) { }) t.Run("should fail on invalid env var", func(t *testing.T) { // create observer - ob := MockBTCObserver(t, chain, params, btcClient, dbpath) + ob := MockBTCObserver(t, chain, params, btcClient) // set invalid environment variable envvar := base.EnvVarLatestBlockByChain(chain) @@ -310,8 +269,7 @@ func Test_LoadLastBlockScanned(t *testing.T) { }) t.Run("should fail on RPC error", func(t *testing.T) { // create observer on separate path, as we need to reset last block scanned - otherPath := sample.CreateTempDir(t) - obOther := MockBTCObserver(t, chain, params, btcClient, otherPath) + obOther := MockBTCObserver(t, chain, params, btcClient) // reset last block scanned to 0 so that it will be loaded from RPC obOther.WithLastBlockScanned(0) @@ -326,7 +284,7 @@ func Test_LoadLastBlockScanned(t *testing.T) { t.Run("should use hardcode block 100 for regtest", func(t *testing.T) { // use regtest chain regtest := chains.BitcoinRegtest - obRegnet := MockBTCObserver(t, regtest, params, btcClient, dbpath) + obRegnet := MockBTCObserver(t, regtest, params, nil) // load last block scanned err := obRegnet.LoadLastBlockScanned() @@ -338,7 +296,7 @@ func Test_LoadLastBlockScanned(t *testing.T) { func TestConfirmationThreshold(t *testing.T) { chain := chains.BitcoinMainnet params := mocks.MockChainParams(chain.ChainId, 10) - ob := MockBTCObserver(t, chain, params, nil, "") + ob := MockBTCObserver(t, chain, params, nil) t.Run("should return confirmations in chain param", func(t *testing.T) { ob.SetChainParams(observertypes.ChainParams{ConfirmationCount: 3}) diff --git a/zetaclient/chains/bitcoin/observer/outbound_test.go b/zetaclient/chains/bitcoin/observer/outbound_test.go index cb43590ff5..d661b1c7bb 100644 --- a/zetaclient/chains/bitcoin/observer/outbound_test.go +++ b/zetaclient/chains/bitcoin/observer/outbound_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcjson" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/zetaclient/chains/base" @@ -27,8 +28,11 @@ func MockBTCObserverMainnet(t *testing.T) *Observer { params := mocks.MockChainParams(chain.ChainId, 10) tss := mocks.NewTSSMainnet() + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + // create Bitcoin observer - ob, err := NewObserver(chain, btcClient, params, nil, tss, testutils.SQLiteMemory, base.Logger{}, nil) + ob, err := NewObserver(chain, btcClient, params, nil, tss, database, base.Logger{}, nil) require.NoError(t, err) return ob diff --git a/zetaclient/chains/bitcoin/rpc/rpc_live_test.go b/zetaclient/chains/bitcoin/rpc/rpc_live_test.go index 7cc0abc11d..54964d7403 100644 --- a/zetaclient/chains/bitcoin/rpc/rpc_live_test.go +++ b/zetaclient/chains/bitcoin/rpc/rpc_live_test.go @@ -19,6 +19,7 @@ import ( "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/zetaclient/chains/base" @@ -54,9 +55,12 @@ func (suite *BitcoinObserverTestSuite) SetupTest() { params := mocks.MockChainParams(chain.ChainId, 10) btcClient := mocks.NewMockBTCRPCClient() + database, err := db.NewFromSqliteInMemory(true) + suite.Require().NoError(err) + // create observer - ob, err := observer.NewObserver(chain, btcClient, params, nil, tss, testutils.SQLiteMemory, - base.DefaultLogger(), nil) + ob, err := observer.NewObserver(chain, btcClient, params, nil, tss, database, base.DefaultLogger(), nil) + suite.Require().NoError(err) suite.Require().NotNil(ob) suite.rpcClient, err = createRPCClient(18332) diff --git a/zetaclient/chains/bitcoin/signer/signer.go b/zetaclient/chains/bitcoin/signer/signer.go index a6a5a3e51d..7019c16755 100644 --- a/zetaclient/chains/bitcoin/signer/signer.go +++ b/zetaclient/chains/bitcoin/signer/signer.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" ethcommon "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/coin" @@ -77,7 +78,7 @@ func NewSigner( } client, err := rpcclient.New(connCfg, nil) if err != nil { - return nil, fmt.Errorf("error creating bitcoin rpc client: %s", err) + return nil, errors.Wrap(err, "unable to create bitcoin rpc client") } return &Signer{ diff --git a/zetaclient/chains/evm/observer/inbound_test.go b/zetaclient/chains/evm/observer/inbound_test.go index de1e003ab8..9e01c214b3 100644 --- a/zetaclient/chains/evm/observer/inbound_test.go +++ b/zetaclient/chains/evm/observer/inbound_test.go @@ -45,7 +45,7 @@ func Test_CheckAndVoteInboundTokenZeta(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenZeta(ctx, tx, receipt, false) require.NoError(t, err) require.Equal(t, cctx.InboundParams.BallotIndex, ballot) @@ -61,7 +61,7 @@ func Test_CheckAndVoteInboundTokenZeta(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - 1 - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) _, err := ob.CheckAndVoteInboundTokenZeta(ctx, tx, receipt, false) require.ErrorContains(t, err, "not been confirmed") }) @@ -77,7 +77,7 @@ func Test_CheckAndVoteInboundTokenZeta(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenZeta(ctx, tx, receipt, true) require.NoError(t, err) require.Equal(t, "", ballot) @@ -101,7 +101,6 @@ func Test_CheckAndVoteInboundTokenZeta(t *testing.T) { nil, nil, nil, - memDBPath, lastBlock, mocks.MockChainParams(chainID, confirmation), ) @@ -132,7 +131,7 @@ func Test_CheckAndVoteInboundTokenERC20(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenERC20(ctx, tx, receipt, false) require.NoError(t, err) require.Equal(t, cctx.InboundParams.BallotIndex, ballot) @@ -148,7 +147,7 @@ func Test_CheckAndVoteInboundTokenERC20(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - 1 - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) _, err := ob.CheckAndVoteInboundTokenERC20(ctx, tx, receipt, false) require.ErrorContains(t, err, "not been confirmed") }) @@ -164,7 +163,7 @@ func Test_CheckAndVoteInboundTokenERC20(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenERC20(ctx, tx, receipt, true) require.NoError(t, err) require.Equal(t, "", ballot) @@ -188,7 +187,6 @@ func Test_CheckAndVoteInboundTokenERC20(t *testing.T) { nil, nil, nil, - memDBPath, lastBlock, mocks.MockChainParams(chainID, confirmation), ) @@ -219,7 +217,7 @@ func Test_CheckAndVoteInboundTokenGas(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenGas(ctx, tx, receipt, false) require.NoError(t, err) require.Equal(t, cctx.InboundParams.BallotIndex, ballot) @@ -229,7 +227,7 @@ func Test_CheckAndVoteInboundTokenGas(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - 1 - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) _, err := ob.CheckAndVoteInboundTokenGas(ctx, tx, receipt, false) require.ErrorContains(t, err, "not been confirmed") }) @@ -239,7 +237,7 @@ func Test_CheckAndVoteInboundTokenGas(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenGas(ctx, tx, receipt, false) require.ErrorContains(t, err, "not TSS address") require.Equal(t, "", ballot) @@ -250,7 +248,7 @@ func Test_CheckAndVoteInboundTokenGas(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenGas(ctx, tx, receipt, false) require.ErrorContains(t, err, "not a successful tx") require.Equal(t, "", ballot) @@ -261,7 +259,7 @@ func Test_CheckAndVoteInboundTokenGas(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(tx)) lastBlock := receipt.BlockNumber.Uint64() + confirmation - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, lastBlock, chainParam) ballot, err := ob.CheckAndVoteInboundTokenGas(ctx, tx, receipt, false) require.NoError(t, err) require.Equal(t, "", ballot) @@ -278,7 +276,7 @@ func Test_BuildInboundVoteMsgForZetaSentEvent(t *testing.T) { cctx := testutils.LoadCctxByInbound(t, chainID, coin.CoinType_Zeta, inboundHash) // parse ZetaSent event - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, mocks.MockChainParams(1, 1)) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, mocks.MockChainParams(1, 1)) connector := mocks.MockConnectorNonEth(t, chainID) event := testutils.ParseReceiptZetaSent(receipt, connector) @@ -327,7 +325,7 @@ func Test_BuildInboundVoteMsgForDepositedEvent(t *testing.T) { cctx := testutils.LoadCctxByInbound(t, chainID, coin.CoinType_ERC20, inboundHash) // parse Deposited event - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, mocks.MockChainParams(1, 1)) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, mocks.MockChainParams(1, 1)) custody := mocks.MockERC20Custody(t, chainID) event := testutils.ParseReceiptERC20Deposited(receipt, custody) sender := ethcommon.HexToAddress(tx.From) @@ -385,7 +383,7 @@ func Test_BuildInboundVoteMsgForTokenSentToTSS(t *testing.T) { require.NoError(t, evm.ValidateEvmTransaction(txDonation)) // create test compliance config - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, mocks.MockChainParams(1, 1)) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, mocks.MockChainParams(1, 1)) cfg := config.Config{ ComplianceConfig: config.ComplianceConfig{}, } @@ -462,7 +460,7 @@ func Test_ObserveTSSReceiveInBlock(t *testing.T) { ctx := context.Background() t.Run("should observe TSS receive in block", func(t *testing.T) { - ob := MockEVMObserver(t, chain, evmClient, evmJSONRPC, zetacoreClient, tss, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, evmClient, evmJSONRPC, zetacoreClient, tss, lastBlock, chainParam) // feed archived block and receipt evmJSONRPC.WithBlock(block) @@ -471,13 +469,13 @@ func Test_ObserveTSSReceiveInBlock(t *testing.T) { require.NoError(t, err) }) t.Run("should not observe on error getting block", func(t *testing.T) { - ob := MockEVMObserver(t, chain, evmClient, evmJSONRPC, zetacoreClient, tss, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, evmClient, evmJSONRPC, zetacoreClient, tss, lastBlock, chainParam) err := ob.ObserveTSSReceiveInBlock(ctx, blockNumber) // error getting block is expected because the mock JSONRPC contains no block require.ErrorContains(t, err, "error getting block") }) t.Run("should not observe on error getting receipt", func(t *testing.T) { - ob := MockEVMObserver(t, chain, evmClient, evmJSONRPC, zetacoreClient, tss, memDBPath, lastBlock, chainParam) + ob := MockEVMObserver(t, chain, evmClient, evmJSONRPC, zetacoreClient, tss, lastBlock, chainParam) evmJSONRPC.WithBlock(block) err := ob.ObserveTSSReceiveInBlock(ctx, blockNumber) // error getting block is expected because the mock evmClient contains no receipt diff --git a/zetaclient/chains/evm/observer/observer.go b/zetaclient/chains/evm/observer/observer.go index a6b69a78b1..dbfafd23e1 100644 --- a/zetaclient/chains/evm/observer/observer.go +++ b/zetaclient/chains/evm/observer/observer.go @@ -26,6 +26,7 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/chains/evm" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" "github.com/zeta-chain/zetacore/zetaclient/config" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/metrics" clienttypes "github.com/zeta-chain/zetacore/zetaclient/types" ) @@ -61,7 +62,7 @@ func NewObserver( chainParams observertypes.ChainParams, zetacoreClient interfaces.ZetacoreClient, tss interfaces.TSSSigner, - dbpath string, + database *db.DB, logger base.Logger, ts *metrics.TelemetryServer, ) (*Observer, error) { @@ -74,10 +75,11 @@ func NewObserver( base.DefaultBlockCacheSize, base.DefaultHeaderCacheSize, ts, + database, logger, ) if err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to create base observer") } // create evm observer @@ -90,10 +92,9 @@ func NewObserver( outboundConfirmedTransactions: make(map[string]*ethtypes.Transaction), } - // open database and load data - err = ob.LoadDB(ctx, dbpath) - if err != nil { - return nil, err + // load last block scanned + if err = ob.LoadLastBlockScanned(ctx); err != nil { + return nil, errors.Wrap(err, "unable to load last block scanned") } return ob, nil @@ -166,6 +167,11 @@ func FetchZetaTokenContract( // Start all observation routines for the evm chain func (ob *Observer) Start(ctx context.Context) { + if noop := ob.Observer.Start(); noop { + ob.Logger().Chain.Info().Msgf("observer is already started for chain %d", ob.Chain().ChainId) + return + } + ob.Logger().Chain.Info().Msgf("observer is starting for chain %d", ob.Chain().ChainId) bg.Work(ctx, ob.WatchInbound, bg.WithName("WatchInbound"), bg.WithLogger(ob.Logger().Inbound)) @@ -422,35 +428,6 @@ func (ob *Observer) BlockByNumber(blockNumber int) (*ethrpc.Block, error) { return block, nil } -// LoadDB open sql database and load data into EVM observer -// TODO(revamp): move to a db file -func (ob *Observer) LoadDB(ctx context.Context, dbPath string) error { - if dbPath == "" { - return errors.New("empty db path") - } - - // open database - err := ob.OpenDB(dbPath, "") - if err != nil { - return errors.Wrapf(err, "error OpenDB for chain %d", ob.Chain().ChainId) - } - - // run auto migration - // transaction and receipt tables are used nowhere but we still run migration in case they are needed in future - err = ob.DB().AutoMigrate( - &clienttypes.ReceiptSQLType{}, - &clienttypes.TransactionSQLType{}, - ) - if err != nil { - return errors.Wrapf(err, "error AutoMigrate for chain %d", ob.Chain().ChainId) - } - - // load last block scanned - err = ob.LoadLastBlockScanned(ctx) - - return err -} - // LoadLastBlockScanned loads the last scanned block from the database // TODO(revamp): move to a db file func (ob *Observer) LoadLastBlockScanned(ctx context.Context) error { diff --git a/zetaclient/chains/evm/observer/observer_test.go b/zetaclient/chains/evm/observer/observer_test.go index c24b5da8f7..da62eb8159 100644 --- a/zetaclient/chains/evm/observer/observer_test.go +++ b/zetaclient/chains/evm/observer/observer_test.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/require" zctx "github.com/zeta-chain/zetacore/zetaclient/context" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/keys" "github.com/zeta-chain/zetacore/pkg/chains" @@ -81,7 +82,6 @@ func MockEVMObserver( evmJSONRPC interfaces.EVMJSONRPCClient, zetacoreClient interfaces.ZetacoreClient, tss interfaces.TSSSigner, - dbpath string, lastBlock uint64, params observertypes.ChainParams, ) *observer.Observer { @@ -107,8 +107,11 @@ func MockEVMObserver( // create zetacore context _, evmCfg := getZetacoreContext(chain, "", ¶ms) + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + // create observer - ob, err := observer.NewObserver(ctx, evmCfg, evmClient, params, zetacoreClient, tss, dbpath, base.Logger{}, nil) + ob, err := observer.NewObserver(ctx, evmCfg, evmClient, params, zetacoreClient, tss, database, base.Logger{}, nil) require.NoError(t, err) ob.WithEvmJSONRPC(evmJSONRPC) ob.WithLastBlock(lastBlock) @@ -130,8 +133,9 @@ func Test_NewObserver(t *testing.T) { chainParams observertypes.ChainParams evmClient interfaces.EVMRPCClient tss interfaces.TSSSigner - dbpath string logger base.Logger + before func() + after func() ts *metrics.TelemetryServer fail bool message string @@ -145,40 +149,45 @@ func Test_NewObserver(t *testing.T) { chainParams: params, evmClient: mocks.NewMockEvmClient().WithBlockNumber(1000), tss: mocks.NewTSSMainnet(), - dbpath: sample.CreateTempDir(t), logger: base.Logger{}, ts: nil, fail: false, }, { - name: "should fail on invalid dbpath", + name: "should fail if RPC call fails", evmCfg: config.EVMConfig{ Chain: chain, Endpoint: "http://localhost:8545", }, chainParams: params, - evmClient: mocks.NewMockEvmClient().WithBlockNumber(1000), + evmClient: mocks.NewMockEvmClient().WithError(fmt.Errorf("error RPC")), tss: mocks.NewTSSMainnet(), - dbpath: "/invalid/dbpath", // invalid dbpath logger: base.Logger{}, ts: nil, fail: true, - message: "error creating db path", + message: "error RPC", }, { - name: "should fail if RPC call fails", + name: "should fail on invalid ENV var", evmCfg: config.EVMConfig{ Chain: chain, Endpoint: "http://localhost:8545", }, chainParams: params, - evmClient: mocks.NewMockEvmClient().WithError(fmt.Errorf("error RPC")), + evmClient: mocks.NewMockEvmClient().WithBlockNumber(1000), tss: mocks.NewTSSMainnet(), - dbpath: sample.CreateTempDir(t), - logger: base.Logger{}, - ts: nil, - fail: true, - message: "error RPC", + before: func() { + envVar := base.EnvVarLatestBlockByChain(chain) + os.Setenv(envVar, "invalid") + }, + after: func() { + envVar := base.EnvVarLatestBlockByChain(chain) + os.Unsetenv(envVar) + }, + logger: base.Logger{}, + ts: nil, + fail: true, + message: "unable to load last block scanned", }, } @@ -189,6 +198,16 @@ func Test_NewObserver(t *testing.T) { //zetacoreCtx, _ := getZetacoreContext(tt.evmCfg.Chain, tt.evmCfg.Endpoint, ¶ms) zetacoreClient := mocks.NewZetacoreClient(t) + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + + if tt.before != nil { + tt.before() + } + if tt.after != nil { + defer tt.after() + } + // create observer ob, err := observer.NewObserver( ctx, @@ -197,7 +216,7 @@ func Test_NewObserver(t *testing.T) { tt.chainParams, zetacoreClient, tt.tss, - tt.dbpath, + database, tt.logger, tt.ts, ) @@ -214,53 +233,6 @@ func Test_NewObserver(t *testing.T) { } } -func Test_LoadDB(t *testing.T) { - ctx := context.Background() - - // use Ethereum chain for testing - chain := chains.Ethereum - params := mocks.MockChainParams(chain.ChainId, 10) - dbpath := sample.CreateTempDir(t) - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, dbpath, 1, params) - - t.Run("should load db successfully", func(t *testing.T) { - err := ob.LoadDB(ctx, dbpath) - require.NoError(t, err) - require.EqualValues(t, 1000, ob.LastBlockScanned()) - }) - t.Run("should fail on invalid dbpath", func(t *testing.T) { - // load db with empty dbpath - err := ob.LoadDB(ctx, "") - require.ErrorContains(t, err, "empty db path") - - // load db with invalid dbpath - err = ob.LoadDB(ctx, "/invalid/dbpath") - require.ErrorContains(t, err, "error OpenDB") - }) - t.Run("should fail on invalid env var", func(t *testing.T) { - // set invalid environment variable - envvar := base.EnvVarLatestBlockByChain(chain) - os.Setenv(envvar, "invalid") - defer os.Unsetenv(envvar) - - // load db - err := ob.LoadDB(ctx, dbpath) - require.ErrorContains(t, err, "error LoadLastBlockScanned") - }) - t.Run("should fail on RPC error", func(t *testing.T) { - // create observer - tempClient := mocks.NewMockEvmClient() - ob := MockEVMObserver(t, chain, tempClient, nil, nil, nil, dbpath, 1, params) - - // set RPC error - tempClient.WithError(fmt.Errorf("error RPC")) - - // load db - err := ob.LoadDB(ctx, dbpath) - require.ErrorContains(t, err, "error RPC") - }) -} - func Test_LoadLastBlockScanned(t *testing.T) { ctx := context.Background() @@ -270,8 +242,7 @@ func Test_LoadLastBlockScanned(t *testing.T) { // create observer using mock evm client evmClient := mocks.NewMockEvmClient().WithBlockNumber(100) - dbpath := sample.CreateTempDir(t) - ob := MockEVMObserver(t, chain, evmClient, nil, nil, nil, dbpath, 1, params) + ob := MockEVMObserver(t, chain, evmClient, nil, nil, nil, 1, params) t.Run("should load last block scanned", func(t *testing.T) { // create db and write 123 as last block scanned @@ -294,8 +265,7 @@ func Test_LoadLastBlockScanned(t *testing.T) { }) t.Run("should fail on RPC error", func(t *testing.T) { // create observer on separate path, as we need to reset last block scanned - otherPath := sample.CreateTempDir(t) - obOther := MockEVMObserver(t, chain, evmClient, nil, nil, nil, otherPath, 1, params) + obOther := MockEVMObserver(t, chain, evmClient, nil, nil, nil, 1, params) // reset last block scanned to 0 so that it will be loaded from RPC obOther.WithLastBlockScanned(0) diff --git a/zetaclient/chains/evm/observer/outbound_test.go b/zetaclient/chains/evm/observer/outbound_test.go index 7342139343..8b0ad1573c 100644 --- a/zetaclient/chains/evm/observer/outbound_test.go +++ b/zetaclient/chains/evm/observer/outbound_test.go @@ -19,8 +19,6 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" ) -const memDBPath = testutils.SQLiteMemory - // getContractsByChainID is a helper func to get contracts and addresses by chainID func getContractsByChainID( t *testing.T, @@ -62,7 +60,7 @@ func Test_IsOutboundProcessed(t *testing.T) { t.Run("should post vote and return true if outbound is processed", func(t *testing.T) { // create evm observer and set outbound and receipt - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, chainParam) ob.SetTxNReceipt(nonce, receipt, outbound) // post outbound vote @@ -79,7 +77,7 @@ func Test_IsOutboundProcessed(t *testing.T) { cctx.InboundParams.Sender = sample.EthAddress().Hex() // create evm observer and set outbound and receipt - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, chainParam) ob.SetTxNReceipt(nonce, receipt, outbound) // modify compliance config to restrict sender address @@ -97,7 +95,7 @@ func Test_IsOutboundProcessed(t *testing.T) { }) t.Run("should return false if outbound is not confirmed", func(t *testing.T) { // create evm observer and DO NOT set outbound as confirmed - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, chainParam) isIncluded, isConfirmed, err := ob.IsOutboundProcessed(ctx, cctx, zerolog.Nop()) require.NoError(t, err) require.False(t, isIncluded) @@ -105,7 +103,7 @@ func Test_IsOutboundProcessed(t *testing.T) { }) t.Run("should fail if unable to parse ZetaReceived event", func(t *testing.T) { // create evm observer and set outbound and receipt - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, chainParam) ob.SetTxNReceipt(nonce, receipt, outbound) // set connector contract address to an arbitrary address to make event parsing fail @@ -155,7 +153,7 @@ func Test_IsOutboundProcessed_ContractError(t *testing.T) { t.Run("should fail if unable to get connector/custody contract", func(t *testing.T) { // create evm observer and set outbound and receipt - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, chainParam) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, chainParam) ob.SetTxNReceipt(nonce, receipt, outbound) abiConnector := zetaconnector.ZetaConnectorNonEthMetaData.ABI abiCustody := erc20custody.ERC20CustodyMetaData.ABI @@ -201,7 +199,7 @@ func Test_PostVoteOutbound(t *testing.T) { receiveStatus := chains.ReceiveStatus_success // create evm client using mock zetacore client and post outbound vote - ob := MockEVMObserver(t, chain, nil, nil, nil, nil, memDBPath, 1, observertypes.ChainParams{}) + ob := MockEVMObserver(t, chain, nil, nil, nil, nil, 1, observertypes.ChainParams{}) ob.PostVoteOutbound( ctx, cctx.Index, diff --git a/zetaclient/chains/evm/signer/signer_test.go b/zetaclient/chains/evm/signer/signer_test.go index 9880bb233e..b0cf3e5504 100644 --- a/zetaclient/chains/evm/signer/signer_test.go +++ b/zetaclient/chains/evm/signer/signer_test.go @@ -12,6 +12,7 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/require" zctx "github.com/zeta-chain/zetacore/zetaclient/context" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/keys" "github.com/zeta-chain/zetacore/pkg/chains" @@ -77,10 +78,12 @@ func getNewEvmChainObserver(t *testing.T, tss interfaces.TSSSigner) (*observer.O params := mocks.MockChainParams(evmcfg.Chain.ChainId, 10) cfg.EVMChainConfigs[chains.BscMainnet.ChainId] = evmcfg //appContext := context.New(cfg, zerolog.Nop()) - dbpath := sample.CreateTempDir(t) logger := base.Logger{} ts := &metrics.TelemetryServer{} + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + return observer.NewObserver( ctx, evmcfg, @@ -88,7 +91,7 @@ func getNewEvmChainObserver(t *testing.T, tss interfaces.TSSSigner) (*observer.O params, mocks.NewZetacoreClient(t), tss, - dbpath, + database, logger, ts, ) diff --git a/zetaclient/chains/solana/observer/db.go b/zetaclient/chains/solana/observer/db.go deleted file mode 100644 index 91757910d7..0000000000 --- a/zetaclient/chains/solana/observer/db.go +++ /dev/null @@ -1,30 +0,0 @@ -package observer - -import ( - "github.com/pkg/errors" -) - -// LoadDB open sql database and load data into Solana observer -func (ob *Observer) LoadDB(dbPath string) error { - if dbPath == "" { - return errors.New("empty db path") - } - - // open database - err := ob.OpenDB(dbPath, "") - if err != nil { - return errors.Wrapf(err, "error OpenDB for chain %d", ob.Chain().ChainId) - } - - ob.Observer.LoadLastTxScanned() - - return nil -} - -// LoadLastTxScanned loads the last scanned tx from the database. -func (ob *Observer) LoadLastTxScanned() error { - ob.Observer.LoadLastTxScanned() - ob.Logger().Chain.Info().Msgf("chain %d starts scanning from tx %s", ob.Chain().ChainId, ob.LastTxScanned()) - - return nil -} diff --git a/zetaclient/chains/solana/observer/inbound_test.go b/zetaclient/chains/solana/observer/inbound_test.go index 2b3692ad62..a7090d6821 100644 --- a/zetaclient/chains/solana/observer/inbound_test.go +++ b/zetaclient/chains/solana/observer/inbound_test.go @@ -12,6 +12,7 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/chains/base" "github.com/zeta-chain/zetacore/zetaclient/chains/solana/observer" "github.com/zeta-chain/zetacore/zetaclient/config" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/keys" "github.com/zeta-chain/zetacore/zetaclient/testutils" "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" @@ -30,14 +31,16 @@ func Test_FilterInboundEventAndVote(t *testing.T) { chain := chains.SolanaDevnet txResult := testutils.LoadSolanaInboundTxResult(t, TestDataDir, chain.ChainId, txHash, false) + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + // create observer chainParams := sample.ChainParams(chain.ChainId) chainParams.GatewayAddress = "2kJndCL9NBR36ySiQ4bmArs4YgWQu67LmCDfLzk5Gb7s" zetacoreClient := mocks.NewZetacoreClient(t) zetacoreClient.WithKeys(&keys.Keys{}).WithZetaChain().WithPostVoteInbound("", "") - dbpath := sample.CreateTempDir(t) - ob, err := observer.NewObserver(chain, nil, *chainParams, zetacoreClient, nil, dbpath, base.DefaultLogger(), nil) + ob, err := observer.NewObserver(chain, nil, *chainParams, zetacoreClient, nil, database, base.DefaultLogger(), nil) require.NoError(t, err) t.Run("should filter inbound events and vote", func(t *testing.T) { @@ -53,11 +56,14 @@ func Test_FilterInboundEvents(t *testing.T) { chain := chains.SolanaDevnet txResult := testutils.LoadSolanaInboundTxResult(t, TestDataDir, chain.ChainId, txHash, false) + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + // create observer chainParams := sample.ChainParams(chain.ChainId) chainParams.GatewayAddress = "2kJndCL9NBR36ySiQ4bmArs4YgWQu67LmCDfLzk5Gb7s" - dbpath := sample.CreateTempDir(t) - ob, err := observer.NewObserver(chain, nil, *chainParams, nil, nil, dbpath, base.DefaultLogger(), nil) + + ob, err := observer.NewObserver(chain, nil, *chainParams, nil, nil, database, base.DefaultLogger(), nil) require.NoError(t, err) // expected result @@ -94,8 +100,10 @@ func Test_BuildInboundVoteMsgFromEvent(t *testing.T) { zetacoreClient := mocks.NewZetacoreClient(t) zetacoreClient.WithKeys(&keys.Keys{}).WithZetaChain().WithPostVoteInbound("", "") - dbpath := sample.CreateTempDir(t) - ob, err := observer.NewObserver(chain, nil, *params, zetacoreClient, nil, dbpath, base.DefaultLogger(), nil) + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + + ob, err := observer.NewObserver(chain, nil, *params, zetacoreClient, nil, database, base.DefaultLogger(), nil) require.NoError(t, err) // create test compliance config @@ -156,11 +164,13 @@ func Test_ParseInboundAsDeposit(t *testing.T) { tx, err := txResult.Transaction.GetTransaction() require.NoError(t, err) + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + // create observer chainParams := sample.ChainParams(chain.ChainId) chainParams.GatewayAddress = "2kJndCL9NBR36ySiQ4bmArs4YgWQu67LmCDfLzk5Gb7s" - dbpath := sample.CreateTempDir(t) - ob, err := observer.NewObserver(chain, nil, *chainParams, nil, nil, dbpath, base.DefaultLogger(), nil) + ob, err := observer.NewObserver(chain, nil, *chainParams, nil, nil, database, base.DefaultLogger(), nil) require.NoError(t, err) // expected result diff --git a/zetaclient/chains/solana/observer/observer.go b/zetaclient/chains/solana/observer/observer.go index ec818732ca..1cdb3f6768 100644 --- a/zetaclient/chains/solana/observer/observer.go +++ b/zetaclient/chains/solana/observer/observer.go @@ -4,6 +4,7 @@ import ( "context" "github.com/gagliardetto/solana-go" + "github.com/pkg/errors" "github.com/zeta-chain/zetacore/pkg/bg" "github.com/zeta-chain/zetacore/pkg/chains" @@ -11,6 +12,7 @@ import ( observertypes "github.com/zeta-chain/zetacore/x/observer/types" "github.com/zeta-chain/zetacore/zetaclient/chains/base" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/metrics" ) @@ -38,7 +40,7 @@ func NewObserver( chainParams observertypes.ChainParams, zetacoreClient interfaces.ZetacoreClient, tss interfaces.TSSSigner, - dbpath string, + db *db.DB, logger base.Logger, ts *metrics.TelemetryServer, ) (*Observer, error) { @@ -51,17 +53,23 @@ func NewObserver( base.DefaultBlockCacheSize, base.DefaultHeaderCacheSize, ts, + db, logger, ) if err != nil { return nil, err } + pubKey, err := solana.PublicKeyFromBase58(chainParams.GatewayAddress) + if err != nil { + return nil, errors.Wrap(err, "unable to derive public key") + } + // create solana observer ob := Observer{ Observer: *baseObserver, solClient: solClient, - gatewayID: solana.MustPublicKeyFromBase58(chainParams.GatewayAddress), + gatewayID: pubKey, } // compute gateway PDA @@ -71,11 +79,7 @@ func NewObserver( return nil, err } - // load btc chain observer DB - err = ob.LoadDB(dbpath) - if err != nil { - return nil, err - } + ob.Observer.LoadLastTxScanned() return &ob, nil } @@ -108,6 +112,11 @@ func (ob *Observer) GetChainParams() observertypes.ChainParams { // Start starts the Go routine processes to observe the Solana chain func (ob *Observer) Start(ctx context.Context) { + if noop := ob.Observer.Start(); noop { + ob.Logger().Chain.Info().Msgf("observer is already started for chain %d", ob.Chain().ChainId) + return + } + ob.Logger().Chain.Info().Msgf("observer is starting for chain %d", ob.Chain().ChainId) // watch Solana chain for incoming txs and post votes to zetacore @@ -116,3 +125,11 @@ func (ob *Observer) Start(ctx context.Context) { // watch zetacore for Solana inbound trackers bg.Work(ctx, ob.WatchInboundTracker, bg.WithName("WatchInboundTracker"), bg.WithLogger(ob.Logger().Inbound)) } + +// LoadLastTxScanned loads the last scanned tx from the database. +func (ob *Observer) LoadLastTxScanned() error { + ob.Observer.LoadLastTxScanned() + ob.Logger().Chain.Info().Msgf("chain %d starts scanning from tx %s", ob.Chain().ChainId, ob.LastTxScanned()) + + return nil +} diff --git a/zetaclient/chains/solana/observer/db_test.go b/zetaclient/chains/solana/observer/observer_test.go similarity index 64% rename from zetaclient/chains/solana/observer/db_test.go rename to zetaclient/chains/solana/observer/observer_test.go index f6fdee1b73..5b2576e66d 100644 --- a/zetaclient/chains/solana/observer/db_test.go +++ b/zetaclient/chains/solana/observer/observer_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/require" observertypes "github.com/zeta-chain/zetacore/x/observer/types" + "github.com/zeta-chain/zetacore/zetaclient/db" "github.com/zeta-chain/zetacore/zetaclient/keys" "github.com/zeta-chain/zetacore/pkg/chains" @@ -23,7 +24,6 @@ func MockSolanaObserver( chainParams observertypes.ChainParams, zetacoreClient interfaces.ZetacoreClient, tss interfaces.TSSSigner, - dbpath string, ) *observer.Observer { // use mock zetacore client if not provided if zetacoreClient == nil { @@ -34,6 +34,9 @@ func MockSolanaObserver( tss = mocks.NewTSSMainnet() } + database, err := db.NewFromSqliteInMemory(true) + require.NoError(t, err) + // create observer ob, err := observer.NewObserver( chain, @@ -41,7 +44,7 @@ func MockSolanaObserver( chainParams, zetacoreClient, tss, - dbpath, + database, base.DefaultLogger(), nil, ) @@ -50,45 +53,14 @@ func MockSolanaObserver( return ob } -func Test_LoadDB(t *testing.T) { - // parepare params - chain := chains.SolanaDevnet - params := sample.ChainParams(chain.ChainId) - params.GatewayAddress = sample.SolanaAddress(t) - dbpath := sample.CreateTempDir(t) - - // create observer - ob := MockSolanaObserver(t, chain, nil, *params, nil, nil, dbpath) - - // write last tx to db - lastTx := sample.SolanaSignature(t).String() - ob.WriteLastTxScannedToDB(lastTx) - - t.Run("should load db successfully", func(t *testing.T) { - err := ob.LoadDB(dbpath) - require.NoError(t, err) - require.Equal(t, lastTx, ob.LastTxScanned()) - }) - t.Run("should fail on invalid dbpath", func(t *testing.T) { - // load db with empty dbpath - err := ob.LoadDB("") - require.ErrorContains(t, err, "empty db path") - - // load db with invalid dbpath - err = ob.LoadDB("/invalid/dbpath") - require.ErrorContains(t, err, "error OpenDB") - }) -} - func Test_LoadLastTxScanned(t *testing.T) { // parepare params chain := chains.SolanaDevnet params := sample.ChainParams(chain.ChainId) params.GatewayAddress = sample.SolanaAddress(t) - dbpath := sample.CreateTempDir(t) // create observer - ob := MockSolanaObserver(t, chain, nil, *params, nil, nil, dbpath) + ob := MockSolanaObserver(t, chain, nil, *params, nil, nil) t.Run("should load last block scanned", func(t *testing.T) { // write sample last tx to db diff --git a/zetaclient/context/app.go b/zetaclient/context/app.go index 139d7c5bef..72310c22af 100644 --- a/zetaclient/context/app.go +++ b/zetaclient/context/app.go @@ -38,33 +38,16 @@ type AppContext struct { mu sync.RWMutex } -// New creates and returns new AppContext +// New creates and returns new empty AppContext func New(cfg config.Config, logger zerolog.Logger) *AppContext { - evmChainParams := make(map[int64]*observertypes.ChainParams) - for _, e := range cfg.EVMChainConfigs { - evmChainParams[e.Chain.ChainId] = &observertypes.ChainParams{} - } - - var bitcoinChainParams *observertypes.ChainParams - _, found := cfg.GetBTCConfig() - if found { - bitcoinChainParams = &observertypes.ChainParams{} - } - - var solanaChainParams *observertypes.ChainParams - _, found = cfg.GetSolanaConfig() - if found { - solanaChainParams = &observertypes.ChainParams{} - } - return &AppContext{ config: cfg, logger: logger.With().Str("module", "appcontext").Logger(), chainsEnabled: []chains.Chain{}, - evmChainParams: evmChainParams, - bitcoinChainParams: bitcoinChainParams, - solanaChainParams: solanaChainParams, + evmChainParams: map[int64]*observertypes.ChainParams{}, + bitcoinChainParams: nil, + solanaChainParams: nil, crosschainFlags: observertypes.CrosschainFlags{}, blockHeaderEnabledChains: []lightclienttypes.HeaderSupportedChain{}, @@ -81,14 +64,17 @@ func (a *AppContext) Config() config.Config { // GetBTCChainAndConfig returns btc chain and config if enabled func (a *AppContext) GetBTCChainAndConfig() (chains.Chain, config.BTCConfig, bool) { - btcConfig, configEnabled := a.Config().GetBTCConfig() - btcChain, _, paramsEnabled := a.GetBTCChainParams() + cfg, configEnabled := a.Config().GetBTCConfig() + if !configEnabled { + return chains.Chain{}, config.BTCConfig{}, false + } - if !configEnabled || !paramsEnabled { + chain, _, paramsEnabled := a.GetBTCChainParams() + if !paramsEnabled { return chains.Chain{}, config.BTCConfig{}, false } - return btcChain, btcConfig, true + return chain, cfg, true } // GetSolanaChainAndConfig returns solana chain and config if enabled @@ -277,15 +263,15 @@ func (a *AppContext) Update( blockHeaderEnabledChains []lightclienttypes.HeaderSupportedChain, init bool, ) { + if len(newChains) == 0 { + a.logger.Warn().Msg("UpdateChainParams: No chains enabled in ZeroCore") + } + // Ignore whatever order zetacore organizes chain list in state sort.SliceStable(newChains, func(i, j int) bool { return newChains[i].ChainId < newChains[j].ChainId }) - if len(newChains) == 0 { - a.logger.Warn().Msg("UpdateChainParams: No chains enabled in ZeroCore") - } - a.mu.Lock() defer a.mu.Unlock() @@ -294,7 +280,7 @@ func (a *AppContext) Update( a.logger.Warn(). Interface("chains.current", a.chainsEnabled). Interface("chains.new", newChains). - Msg("UpdateChainParams: ChainsEnabled changed at runtime!") + Msg("ChainsEnabled changed at runtime!") } if keygen != nil { @@ -306,25 +292,37 @@ func (a *AppContext) Update( a.additionalChain = additionalChains a.blockHeaderEnabledChains = blockHeaderEnabledChains + // update core params for evm chains we have configs in file + freshEvmChainParams := make(map[int64]*observertypes.ChainParams) + for _, cp := range evmChainParams { + _, found := a.config.EVMChainConfigs[cp.ChainId] + if !found { + a.logger.Warn(). + Int64("chain.id", cp.ChainId). + Msg("Encountered EVM ChainParams that are not present in the config file") + + continue + } + + if chains.IsZetaChain(cp.ChainId, nil) { + continue + } + + freshEvmChainParams[cp.ChainId] = cp + } + + a.evmChainParams = freshEvmChainParams + // update chain params for bitcoin if it has config in file - if a.bitcoinChainParams != nil && btcChainParams != nil { + if btcChainParams != nil { a.bitcoinChainParams = btcChainParams } // update chain params for solana if it has config in file - if a.solanaChainParams != nil && solChainParams != nil { + if solChainParams != nil { a.solanaChainParams = solChainParams } - // update core params for evm chains we have configs in file - for _, params := range evmChainParams { - _, found := a.evmChainParams[params.ChainId] - if !found { - continue - } - a.evmChainParams[params.ChainId] = params - } - if tssPubKey != "" { a.currentTssPubkey = tssPubKey } diff --git a/zetaclient/context/app_test.go b/zetaclient/context/app_test.go index 1786242b20..3e19772578 100644 --- a/zetaclient/context/app_test.go +++ b/zetaclient/context/app_test.go @@ -6,6 +6,7 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/testutil/sample" @@ -69,37 +70,48 @@ func TestNew(t *testing.T) { }) t.Run("should create new zetacore context with config containing evm chain params", func(t *testing.T) { - testCfg := config.New(false) + // ARRANGE + var ( + eth = chains.Ethereum.ChainId + matic = chains.Polygon.ChainId + + testCfg = config.New(false) + + ethChainParams = mocks.MockChainParams(eth, 200) + maticChainParams = mocks.MockChainParams(matic, 333) + ) + + // Given config with evm chain params (e.g. from a file) testCfg.EVMChainConfigs = map[int64]config.EVMConfig{ - 1: { - Chain: chains.Chain{ - ChainName: 1, - ChainId: 1, - }, - }, - 2: { - Chain: chains.Chain{ - ChainName: 2, - ChainId: 2, - }, - }, + eth: {Chain: chains.Ethereum}, + matic: {Chain: chains.Polygon}, } + + // And chain params from zetacore + chainParams := map[int64]*observertypes.ChainParams{ + eth: ðChainParams, + matic: &maticChainParams, + } + + // Given app context appContext := context.New(testCfg, logger) - require.NotNil(t, appContext) + + // That was updated with chain params + appContext.Update(nil, nil, chainParams, nil, nil, "", observertypes.CrosschainFlags{}, nil, nil, false) // assert evm chain params allEVMChainParams := appContext.GetAllEVMChainParams() require.Equal(t, 2, len(allEVMChainParams)) - require.Equal(t, &observertypes.ChainParams{}, allEVMChainParams[1]) - require.Equal(t, &observertypes.ChainParams{}, allEVMChainParams[2]) + require.Equal(t, ðChainParams, allEVMChainParams[eth]) + require.Equal(t, &maticChainParams, allEVMChainParams[matic]) - evmChainParams1, found := appContext.GetEVMChainParams(1) + evmChainParams1, found := appContext.GetEVMChainParams(eth) require.True(t, found) - require.Equal(t, &observertypes.ChainParams{}, evmChainParams1) + require.Equal(t, ðChainParams, evmChainParams1) - evmChainParams2, found := appContext.GetEVMChainParams(2) + evmChainParams2, found := appContext.GetEVMChainParams(matic) require.True(t, found) - require.Equal(t, &observertypes.ChainParams{}, evmChainParams2) + require.Equal(t, &maticChainParams, evmChainParams2) }) t.Run("should create new zetacore context with config containing btc config", func(t *testing.T) { diff --git a/zetaclient/db/db.go b/zetaclient/db/db.go new file mode 100644 index 0000000000..49a3b50b3f --- /dev/null +++ b/zetaclient/db/db.go @@ -0,0 +1,112 @@ +// Package db represents API for database operations. +package db + +import ( + "fmt" + "os" + "strings" + + "github.com/pkg/errors" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/zeta-chain/zetacore/zetaclient/types" +) + +// SqliteInMemory is a special string to use in-memory database. +// @see https://www.sqlite.org/inmemorydb.html +const SqliteInMemory = ":memory:" + +// read/write/execute for user +// read/write for group +const dirCreationMode = 0o750 + +var ( + defaultGormConfig = &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + } + migrationEntities = []any{ + &types.LastBlockSQLType{}, + &types.TransactionSQLType{}, + &types.ReceiptSQLType{}, + &types.TransactionResultSQLType{}, + &types.OutboundHashSQLType{}, + &types.LastTransactionSQLType{}, + } +) + +// DB database. +type DB struct { + db *gorm.DB +} + +// NewFromSqlite creates a new instance of DB based on SQLite database. +func NewFromSqlite(directory, dbName string, migrate bool) (*DB, error) { + path, err := ensurePath(directory, dbName) + if err != nil { + return nil, errors.Wrap(err, "unable to ensure database path") + } + + return New(sqlite.Open(path), migrate) +} + +// NewFromSqliteInMemory creates a new instance of DB based on SQLite in-memory database. +func NewFromSqliteInMemory(migrate bool) (*DB, error) { + return NewFromSqlite(SqliteInMemory, "", migrate) +} + +// New creates a new instance of DB. +func New(dial gorm.Dialector, migrate bool) (*DB, error) { + // open db + db, err := gorm.Open(dial, defaultGormConfig) + if err != nil { + return nil, errors.Wrap(err, "unable to open gorm database") + } + + if migrate { + if err := db.AutoMigrate(migrationEntities...); err != nil { + return nil, errors.Wrap(err, "unable to migrate database") + } + } + + return &DB{db}, nil +} + +// Client returns the underlying gorm database. +func (db *DB) Client() *gorm.DB { + return db.db +} + +// Close closes the database. +func (db *DB) Close() error { + sqlDB, err := db.db.DB() + if err != nil { + return errors.Wrap(err, "unable to get underlying sql.DB") + } + + if err := sqlDB.Close(); err != nil { + return errors.Wrap(err, "unable to close sql.DB") + } + + return nil +} + +func ensurePath(directory, dbName string) (string, error) { + // pass in-memory database as is + if strings.Contains(directory, SqliteInMemory) { + return directory, nil + } + + _, err := os.Stat(directory) + switch { + case os.IsNotExist(err): + if err := os.MkdirAll(directory, dirCreationMode); err != nil { + return "", errors.Wrapf(err, "unable to create database path %q", directory) + } + case err != nil: + return "", errors.Wrap(err, "unable to check database path") + } + + return fmt.Sprintf("%s/%s", directory, dbName), nil +} diff --git a/zetaclient/db/db_test.go b/zetaclient/db/db_test.go new file mode 100644 index 0000000000..da5ab2f6ff --- /dev/null +++ b/zetaclient/db/db_test.go @@ -0,0 +1,96 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/zetaclient/types" +) + +func TestNew(t *testing.T) { + t.Run("in memory alias", func(t *testing.T) { + // ARRANGE + // Given a database + db, err := NewFromSqliteInMemory(true) + require.NoError(t, err) + require.NotNil(t, db) + + // ACT + runSampleSetGetTest(t, db) + + // Close the database + assert.NoError(t, db.Close()) + }) + + t.Run("in memory", func(t *testing.T) { + // ARRANGE + // Given a database + db, err := NewFromSqlite(SqliteInMemory, "", true) + require.NoError(t, err) + require.NotNil(t, db) + + // ACT + runSampleSetGetTest(t, db) + + // Close the database + assert.NoError(t, db.Close()) + }) + + t.Run("file based", func(t *testing.T) { + // ARRANGE + // Given a tmp path + directory, dbName := t.TempDir(), "test.db" + + // Given a database + db, err := NewFromSqlite(directory, dbName, true) + require.NoError(t, err) + require.NotNil(t, db) + + // Check that the database file exists + assert.FileExists(t, directory+"/"+dbName) + + // ACT + runSampleSetGetTest(t, db) + + // Close the database + assert.NoError(t, db.Close()) + + t.Run("close twice", func(t *testing.T) { + require.NoError(t, db.Close()) + }) + }) + + t.Run("invalid file path", func(t *testing.T) { + // ARRANGE + // Given a tmp path + directory, dbName := "///hello", "test.db" + + // Given a database + db, err := NewFromSqlite(directory, dbName, true) + require.ErrorContains(t, err, "unable to ensure database path") + require.Nil(t, db) + }) +} + +func runSampleSetGetTest(t *testing.T, db *DB) { + // Given a dummy sql type + entity := types.ToLastBlockSQLType(444) + + // ACT #1 + // Create entity + result := db.Client().Create(&entity) + + // ASSERT + assert.NoError(t, result.Error) + + // ACT #2 + // Fetch entity + var entity2 types.LastBlockSQLType + + result = db.Client().First(&entity2) + + // ASSERT + assert.NoError(t, result.Error) + assert.Equal(t, entity.Num, entity2.Num) +} diff --git a/zetaclient/orchestrator/bootstap_test.go b/zetaclient/orchestrator/bootstap_test.go new file mode 100644 index 0000000000..555c830df5 --- /dev/null +++ b/zetaclient/orchestrator/bootstap_test.go @@ -0,0 +1,535 @@ +package orchestrator + +import ( + "context" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/pkg/ptr" + observertypes "github.com/zeta-chain/zetacore/x/observer/types" + "github.com/zeta-chain/zetacore/zetaclient/chains/base" + "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" + "github.com/zeta-chain/zetacore/zetaclient/config" + zctx "github.com/zeta-chain/zetacore/zetaclient/context" + "github.com/zeta-chain/zetacore/zetaclient/db" + "github.com/zeta-chain/zetacore/zetaclient/metrics" + "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" + "github.com/zeta-chain/zetacore/zetaclient/testutils/testrpc" +) + +const solanaGatewayAddress = "2kJndCL9NBR36ySiQ4bmArs4YgWQu67LmCDfLzk5Gb7s" + +func TestCreateSignerMap(t *testing.T) { + var ( + ts = metrics.NewTelemetryServer() + tss = mocks.NewTSSMainnet() + log = zerolog.New(zerolog.NewTestWriter(t)) + baseLogger = base.Logger{Std: log, Compliance: log} + ) + + t.Run("CreateSignerMap", func(t *testing.T) { + // ARRANGE + // Given a BTC server + _, btcConfig := testrpc.NewBtcServer(t) + + // Given a zetaclient config with ETH, MATIC, and BTC chains + cfg := config.New(false) + + cfg.EVMChainConfigs[chains.Ethereum.ChainId] = config.EVMConfig{ + Chain: chains.Ethereum, + Endpoint: mocks.EVMRPCEnabled, + } + + cfg.EVMChainConfigs[chains.Polygon.ChainId] = config.EVMConfig{ + Chain: chains.Polygon, + Endpoint: mocks.EVMRPCEnabled, + } + + cfg.BitcoinConfig = btcConfig + + // Given AppContext + app := zctx.New(cfg, log) + ctx := zctx.WithAppContext(context.Background(), app) + + // Given chain & chainParams "fetched" from zetacore + // (note that slice LACKS polygon chain on purpose) + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, chains.BitcoinMainnet, + }) + + // ACT + signers, err := CreateSignerMap(ctx, tss, baseLogger, ts) + + // ASSERT + assert.NoError(t, err) + assert.NotEmpty(t, signers) + + // Okay, now we want to check that signers for EVM and BTC were created + assert.Equal(t, 2, len(signers)) + hasSigner(t, signers, chains.Ethereum.ChainId) + hasSigner(t, signers, chains.BitcoinMainnet.ChainId) + + t.Run("Add polygon in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, chains.BitcoinMainnet, chains.Polygon, + }) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasSigner(t, signers, chains.Ethereum.ChainId) + hasSigner(t, signers, chains.Polygon.ChainId) + hasSigner(t, signers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Disable ethereum in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Polygon, chains.BitcoinMainnet, + }) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 1, removed) + + missesSigner(t, signers, chains.Ethereum.ChainId) + hasSigner(t, signers, chains.Polygon.ChainId) + hasSigner(t, signers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Re-enable ethereum in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, + chains.Polygon, + chains.BitcoinMainnet, + }) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasSigner(t, signers, chains.Ethereum.ChainId) + hasSigner(t, signers, chains.Polygon.ChainId) + hasSigner(t, signers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Disable btc in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, + chains.Polygon, + }) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 1, removed) + + hasSigner(t, signers, chains.Ethereum.ChainId) + hasSigner(t, signers, chains.Polygon.ChainId) + missesSigner(t, signers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Re-enable btc in the runtime", func(t *testing.T) { + // ARRANGE + // Given updated data from zetacore containing polygon chain + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, + chains.Polygon, + chains.BitcoinMainnet, + }) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasSigner(t, signers, chains.Ethereum.ChainId) + hasSigner(t, signers, chains.Polygon.ChainId) + hasSigner(t, signers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Polygon is there but not supported, should be disabled", func(t *testing.T) { + // ARRANGE + // Given updated data from zetacore containing polygon chain + supportedChain, evmParams, btcParams, solParams := chainParams([]chains.Chain{ + chains.Ethereum, + chains.Polygon, + chains.BitcoinMainnet, + }) + + // BUT (!) it's disabled via zetacore + evmParams[chains.Polygon.ChainId].IsSupported = false + + mustUpdateAppContext(t, app, supportedChain, evmParams, btcParams, solParams) + + // Should have signer BEFORE disabling + hasSigner(t, signers, chains.Polygon.ChainId) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 1, removed) + + hasSigner(t, signers, chains.Ethereum.ChainId) + missesSigner(t, signers, chains.Polygon.ChainId) + hasSigner(t, signers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("No changes", func(t *testing.T) { + // ARRANGE + before := len(signers) + + // ACT + added, removed, err := syncSignerMap(ctx, tss, baseLogger, ts, &signers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 0, removed) + assert.Equal(t, before, len(signers)) + }) + }) +} + +func TestCreateChainObserverMap(t *testing.T) { + var ( + ts = metrics.NewTelemetryServer() + tss = mocks.NewTSSMainnet() + log = zerolog.New(zerolog.NewTestWriter(t)) + baseLogger = base.Logger{Std: log, Compliance: log} + client = mocks.NewZetacoreClient(t) + dbPath = db.SqliteInMemory + ) + + t.Run("CreateChainObserverMap", func(t *testing.T) { + // ARRANGE + // Given a BTC server + btcServer, btcConfig := testrpc.NewBtcServer(t) + + btcServer.SetBlockCount(123) + + // Given generic EVM RPC + evmServer := testrpc.NewEVMServer(t) + evmServer.SetBlockNumber(100) + + // Given generic SOL RPC + _, solConfig := testrpc.NewSolanaServer(t) + + // Given a zetaclient config with ETH, MATIC, and BTC chains + cfg := config.New(false) + + cfg.EVMChainConfigs[chains.Ethereum.ChainId] = config.EVMConfig{ + Chain: chains.Ethereum, + Endpoint: evmServer.Endpoint, + } + + cfg.EVMChainConfigs[chains.Polygon.ChainId] = config.EVMConfig{ + Chain: chains.Polygon, + Endpoint: evmServer.Endpoint, + } + + cfg.BitcoinConfig = btcConfig + cfg.SolanaConfig = solConfig + + // Given AppContext + app := zctx.New(cfg, log) + ctx := zctx.WithAppContext(context.Background(), app) + + // Given chain & chainParams "fetched" from zetacore + // (note that slice LACKS polygon & SOL chains on purpose) + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, + chains.BitcoinMainnet, + }) + + // ACT + observers, err := CreateChainObserverMap(ctx, client, tss, dbPath, baseLogger, ts) + + // ASSERT + assert.NoError(t, err) + assert.NotEmpty(t, observers) + + // Okay, now we want to check that signers for EVM and BTC were created + assert.Equal(t, 2, len(observers)) + hasObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + + t.Run("Add polygon in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, chains.BitcoinMainnet, chains.Polygon, + }) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.Polygon.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Add solana in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, + chains.BitcoinMainnet, + chains.Polygon, + chains.SolanaMainnet, + }) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.Polygon.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + hasObserver(t, observers, chains.SolanaMainnet.ChainId) + }) + + t.Run("Disable ethereum and solana in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.BitcoinMainnet, + chains.Polygon, + }) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 2, removed) + + missesObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.Polygon.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + missesObserver(t, observers, chains.SolanaMainnet.ChainId) + }) + + t.Run("Re-enable ethereum in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, chains.BitcoinMainnet, chains.Polygon, + }) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.Polygon.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Disable btc in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.Ethereum, chains.Polygon, + }) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 1, removed) + + hasObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.Polygon.ChainId) + missesObserver(t, observers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Re-enable btc in the runtime", func(t *testing.T) { + // ARRANGE + mustUpdateAppContextChainParams(t, app, []chains.Chain{ + chains.BitcoinMainnet, chains.Ethereum, chains.Polygon, + }) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 1, added) + assert.Equal(t, 0, removed) + + hasObserver(t, observers, chains.Ethereum.ChainId) + hasObserver(t, observers, chains.Polygon.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("Polygon is there but not supported, should be disabled", func(t *testing.T) { + // ARRANGE + // Given updated data from zetacore containing polygon chain + supportedChain, evmParams, btcParams, solParams := chainParams([]chains.Chain{ + chains.Ethereum, + chains.Polygon, + chains.BitcoinMainnet, + }) + + // BUT (!) it's disabled via zetacore + evmParams[chains.Polygon.ChainId].IsSupported = false + + mustUpdateAppContext(t, app, supportedChain, evmParams, btcParams, solParams) + + // Should have signer BEFORE disabling + hasObserver(t, observers, chains.Polygon.ChainId) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 1, removed) + + hasObserver(t, observers, chains.Ethereum.ChainId) + missesObserver(t, observers, chains.Polygon.ChainId) + hasObserver(t, observers, chains.BitcoinMainnet.ChainId) + }) + + t.Run("No changes", func(t *testing.T) { + // ARRANGE + before := len(observers) + + // ACT + added, removed, err := syncObserverMap(ctx, client, tss, dbPath, baseLogger, ts, &observers) + + // ASSERT + assert.NoError(t, err) + assert.Equal(t, 0, added) + assert.Equal(t, 0, removed) + assert.Equal(t, before, len(observers)) + }) + }) +} + +func chainParams(supportedChains []chains.Chain) ( + []chains.Chain, + map[int64]*observertypes.ChainParams, + *observertypes.ChainParams, + *observertypes.ChainParams, +) { + var ( + evmParams = make(map[int64]*observertypes.ChainParams) + btcParams = &observertypes.ChainParams{} + solParams = &observertypes.ChainParams{} + ) + + for _, chain := range supportedChains { + if chains.IsBitcoinChain(chain.ChainId, nil) { + btcParams = &observertypes.ChainParams{ + ChainId: chain.ChainId, + IsSupported: true, + } + + continue + } + + if chains.IsSolanaChain(chain.ChainId, nil) { + solParams = &observertypes.ChainParams{ + ChainId: chain.ChainId, + IsSupported: true, + GatewayAddress: solanaGatewayAddress, + } + } + + if chains.IsEVMChain(chain.ChainId, nil) { + evmParams[chain.ChainId] = ptr.Ptr(mocks.MockChainParams(chain.ChainId, 100)) + } + } + + return supportedChains, evmParams, btcParams, solParams +} + +func mustUpdateAppContextChainParams(t *testing.T, app *zctx.AppContext, chains []chains.Chain) { + supportedChain, evmParams, btcParams, solParams := chainParams(chains) + mustUpdateAppContext(t, app, supportedChain, evmParams, btcParams, solParams) +} + +func mustUpdateAppContext( + _ *testing.T, + app *zctx.AppContext, + chains []chains.Chain, + evmParams map[int64]*observertypes.ChainParams, + utxoParams *observertypes.ChainParams, + solParams *observertypes.ChainParams, +) { + app.Update( + ptr.Ptr(app.GetKeygen()), + chains, + evmParams, + utxoParams, + solParams, + app.GetCurrentTssPubKey(), + app.GetCrossChainFlags(), + app.GetAdditionalChains(), + nil, + false, + ) +} + +func hasSigner(t *testing.T, signers map[int64]interfaces.ChainSigner, chainId int64) { + signer, ok := signers[chainId] + assert.True(t, ok, "missing signer for chain %d", chainId) + assert.NotEmpty(t, signer) +} + +func missesSigner(t *testing.T, signers map[int64]interfaces.ChainSigner, chainId int64) { + _, ok := signers[chainId] + assert.False(t, ok, "unexpected signer for chain %d", chainId) +} + +func hasObserver(t *testing.T, observer map[int64]interfaces.ChainObserver, chainId int64) { + signer, ok := observer[chainId] + assert.True(t, ok, "missing observer for chain %d", chainId) + assert.NotEmpty(t, signer) +} + +func missesObserver(t *testing.T, observer map[int64]interfaces.ChainObserver, chainId int64) { + _, ok := observer[chainId] + assert.False(t, ok, "unexpected observer for chain %d", chainId) +} diff --git a/zetaclient/orchestrator/bootstrap.go b/zetaclient/orchestrator/bootstrap.go new file mode 100644 index 0000000000..5ce43a2ecd --- /dev/null +++ b/zetaclient/orchestrator/bootstrap.go @@ -0,0 +1,407 @@ +package orchestrator + +import ( + "context" + + ethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethclient" + solrpc "github.com/gagliardetto/solana-go/rpc" + "github.com/pkg/errors" + + "github.com/zeta-chain/zetacore/zetaclient/chains/base" + btcobserver "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/observer" + "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/rpc" + btcsigner "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/signer" + evmobserver "github.com/zeta-chain/zetacore/zetaclient/chains/evm/observer" + evmsigner "github.com/zeta-chain/zetacore/zetaclient/chains/evm/signer" + "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" + solbserver "github.com/zeta-chain/zetacore/zetaclient/chains/solana/observer" + "github.com/zeta-chain/zetacore/zetaclient/config" + zctx "github.com/zeta-chain/zetacore/zetaclient/context" + "github.com/zeta-chain/zetacore/zetaclient/db" + "github.com/zeta-chain/zetacore/zetaclient/metrics" +) + +// btcDatabaseFilename is the Bitcoin database file name now used in mainnet, +// so we keep using it here for backward compatibility +const btcDatabaseFilename = "btc_chain_client" + +// CreateSignerMap creates a map of interfaces.ChainSigner (by chainID) for all chains in the config. +// Note that signer construction failure for a chain does not prevent the creation of signers for other chains. +func CreateSignerMap( + ctx context.Context, + tss interfaces.TSSSigner, + logger base.Logger, + ts *metrics.TelemetryServer, +) (map[int64]interfaces.ChainSigner, error) { + signers := make(map[int64]interfaces.ChainSigner) + _, _, err := syncSignerMap(ctx, tss, logger, ts, &signers) + if err != nil { + return nil, err + } + + return signers, nil +} + +// syncSignerMap synchronizes the given signers map with the signers for all chains in the config. +// This semantic is used to allow dynamic updates to the signers map. +// Note that data race handling is the responsibility of the caller. +func syncSignerMap( + ctx context.Context, + tss interfaces.TSSSigner, + logger base.Logger, + ts *metrics.TelemetryServer, + signers *map[int64]interfaces.ChainSigner, +) (int, int, error) { + if signers == nil { + return 0, 0, errors.New("signers map is nil") + } + + app, err := zctx.FromContext(ctx) + if err != nil { + return 0, 0, errors.Wrapf(err, "failed to get app context") + } + + var ( + added, removed int + + presentChainIDs = make([]int64, 0) + + onAfterAdd = func(chainID int64, _ interfaces.ChainSigner) { + logger.Std.Info().Msgf("Added signer for chain %d", chainID) + added++ + } + + addSigner = func(chainID int64, signer interfaces.ChainSigner) { + mapSet[int64, interfaces.ChainSigner](signers, chainID, signer, onAfterAdd) + } + + onBeforeRemove = func(chainID int64, _ interfaces.ChainSigner) { + logger.Std.Info().Msgf("Removing signer for chain %d", chainID) + removed++ + } + ) + + // EVM signers + for _, evmConfig := range app.Config().GetAllEVMConfigs() { + chainID := evmConfig.Chain.ChainId + + evmChainParams, found := app.GetEVMChainParams(chainID) + switch { + case !found: + logger.Std.Warn().Msgf("Unable to find chain params for EVM chain %d", chainID) + continue + case !evmChainParams.IsSupported: + logger.Std.Warn().Msgf("EVM chain %d is not supported", chainID) + continue + } + + presentChainIDs = append(presentChainIDs, chainID) + + // noop for existing signers + if mapHas(signers, chainID) { + continue + } + + var ( + mpiAddress = ethcommon.HexToAddress(evmChainParams.ConnectorContractAddress) + erc20CustodyAddress = ethcommon.HexToAddress(evmChainParams.Erc20CustodyContractAddress) + ) + + signer, err := evmsigner.NewSigner( + ctx, + evmConfig.Chain, + tss, + ts, + logger, + evmConfig.Endpoint, + config.GetConnectorABI(), + config.GetERC20CustodyABI(), + mpiAddress, + erc20CustodyAddress, + ) + if err != nil { + logger.Std.Error().Err(err).Msgf("Unable to construct signer for EVM chain %d", chainID) + continue + } + + addSigner(chainID, signer) + } + + // BTC signer + // Emulate same loop semantics as for EVM chains + for i := 0; i < 1; i++ { + btcChain, btcChainParams, btcChainParamsFound := app.GetBTCChainParams() + switch { + case !btcChainParamsFound: + logger.Std.Warn().Msgf("Unable to find chain params for BTC chain") + continue + case !btcChainParams.IsSupported: + logger.Std.Warn().Msgf("BTC chain is not supported") + continue + } + + chainID := btcChainParams.ChainId + + presentChainIDs = append(presentChainIDs, chainID) + + // noop + if mapHas(signers, chainID) { + continue + } + + cfg, _ := app.Config().GetBTCConfig() + + utxoSigner, err := btcsigner.NewSigner(btcChain, tss, ts, logger, cfg) + if err != nil { + logger.Std.Error().Err(err).Msgf("Unable to construct signer for UTXO chain %d", chainID) + continue + } + + addSigner(chainID, utxoSigner) + } + + // Remove all disabled signers + mapDeleteMissingKeys(signers, presentChainIDs, onBeforeRemove) + + return added, removed, nil +} + +// CreateChainObserverMap creates a map of interfaces.ChainObserver (by chainID) for all chains in the config. +// - Note (!) that it calls observer.Start() on creation +// - Note that data race handling is the responsibility of the caller. +func CreateChainObserverMap( + ctx context.Context, + client interfaces.ZetacoreClient, + tss interfaces.TSSSigner, + dbpath string, + logger base.Logger, + ts *metrics.TelemetryServer, +) (map[int64]interfaces.ChainObserver, error) { + observerMap := make(map[int64]interfaces.ChainObserver) + + _, _, err := syncObserverMap(ctx, client, tss, dbpath, logger, ts, &observerMap) + if err != nil { + return nil, err + } + + return observerMap, nil +} + +// syncObserverMap synchronizes the given observer map with the observers for all chains in the config. +// This semantic is used to allow dynamic updates to the map. +// Note (!) that it calls observer.Start() on creation and observer.Stop() on deletion. +func syncObserverMap( + ctx context.Context, + client interfaces.ZetacoreClient, + tss interfaces.TSSSigner, + dbpath string, + logger base.Logger, + ts *metrics.TelemetryServer, + observerMap *map[int64]interfaces.ChainObserver, +) (int, int, error) { + app, err := zctx.FromContext(ctx) + if err != nil { + return 0, 0, errors.Wrapf(err, "failed to get app context") + } + + var ( + added, removed int + + presentChainIDs = make([]int64, 0) + + onAfterAdd = func(_ int64, ob interfaces.ChainObserver) { + ob.Start(ctx) + added++ + } + + addObserver = func(chainID int64, ob interfaces.ChainObserver) { + mapSet[int64, interfaces.ChainObserver](observerMap, chainID, ob, onAfterAdd) + } + + onBeforeRemove = func(_ int64, ob interfaces.ChainObserver) { + ob.Stop() + removed++ + } + ) + + // EVM observers + for _, evmConfig := range app.Config().GetAllEVMConfigs() { + var ( + chainID = evmConfig.Chain.ChainId + chainName = evmConfig.Chain.ChainName.String() + ) + + chainParams, found := app.GetEVMChainParams(evmConfig.Chain.ChainId) + switch { + case !found: + logger.Std.Error().Msgf("Unable to find chain params for EVM chain %d", chainID) + continue + case !chainParams.IsSupported: + logger.Std.Error().Msgf("EVM chain %d is not supported", chainID) + continue + } + + presentChainIDs = append(presentChainIDs, chainID) + + // noop + if mapHas(observerMap, chainID) { + continue + } + + // create EVM client + evmClient, err := ethclient.DialContext(ctx, evmConfig.Endpoint) + if err != nil { + logger.Std.Error().Err(err).Str("rpc.endpoint", evmConfig.Endpoint).Msgf("Unable to dial EVM RPC") + continue + } + + database, err := db.NewFromSqlite(dbpath, chainName, true) + if err != nil { + logger.Std.Error().Err(err).Msgf("Unable to open a database for EVM chain %q", chainName) + continue + } + + // create EVM chain observer + observer, err := evmobserver.NewObserver( + ctx, + evmConfig, + evmClient, + *chainParams, + client, + tss, + database, + logger, + ts, + ) + if err != nil { + logger.Std.Error().Err(err).Msgf("NewObserver error for EVM chain %s", evmConfig.Chain.String()) + continue + } + + addObserver(chainID, observer) + } + + // Emulate same loop semantics as for EVM chains + // create BTC chain observer + for i := 0; i < 1; i++ { + btcChain, btcConfig, btcEnabled := app.GetBTCChainAndConfig() + if !btcEnabled { + continue + } + + chainID := btcChain.ChainId + + _, btcChainParams, found := app.GetBTCChainParams() + switch { + case !found: + logger.Std.Warn().Msgf("Unable to find chain params for BTC chain %d", chainID) + continue + case !btcChainParams.IsSupported: + logger.Std.Warn().Msgf("BTC chain %d is not supported", chainID) + continue + } + + presentChainIDs = append(presentChainIDs, chainID) + + // noop + if mapHas(observerMap, chainID) { + continue + } + + btcRPC, err := rpc.NewRPCClient(btcConfig) + if err != nil { + logger.Std.Error().Err(err).Msgf("unable to create rpc client for BTC chain %d", chainID) + continue + } + + database, err := db.NewFromSqlite(dbpath, btcDatabaseFilename, true) + if err != nil { + logger.Std.Error().Err(err).Msgf("unable to open database for BTC chain %d", chainID) + continue + } + + btcObserver, err := btcobserver.NewObserver( + btcChain, + btcRPC, + *btcChainParams, + client, + tss, + database, + logger, + ts, + ) + if err != nil { + logger.Std.Error().Err(err).Msgf("NewObserver error for BTC chain %d", chainID) + continue + } + + addObserver(chainID, btcObserver) + } + + // Emulate same loop semantics as for EVM chains + // create SOL chain observer + for i := 0; i < 1; i++ { + solChain, solConfig, solEnabled := app.GetSolanaChainAndConfig() + if !solEnabled { + continue + } + + var ( + chainID = solChain.ChainId + chainName = solChain.ChainName.String() + ) + + _, solanaChainParams, found := app.GetSolanaChainParams() + switch { + case !found: + logger.Std.Warn().Msgf("Unable to find chain params for SOL chain %d", chainID) + continue + case !solanaChainParams.IsSupported: + logger.Std.Warn().Msgf("SOL chain %d is not supported", chainID) + continue + } + + presentChainIDs = append(presentChainIDs, chainID) + + // noop + if mapHas(observerMap, chainID) { + continue + } + + rpcClient := solrpc.New(solConfig.Endpoint) + if rpcClient == nil { + // should never happen + logger.Std.Error().Msg("solana create Solana client error") + continue + } + + database, err := db.NewFromSqlite(dbpath, chainName, true) + if err != nil { + logger.Std.Error().Err(err).Msgf("unable to open database for SOL chain %d", chainID) + continue + } + + solObserver, err := solbserver.NewObserver( + solChain, + rpcClient, + *solanaChainParams, + client, + tss, + database, + logger, + ts, + ) + if err != nil { + logger.Std.Error().Err(err).Msgf("NewObserver error for SOL chain %d", chainID) + continue + } + + addObserver(chainID, solObserver) + } + + // Remove all disabled observers + mapDeleteMissingKeys(observerMap, presentChainIDs, onBeforeRemove) + + return added, removed, nil +} diff --git a/zetaclient/orchestrator/mapping.go b/zetaclient/orchestrator/mapping.go new file mode 100644 index 0000000000..bb2a8d3774 --- /dev/null +++ b/zetaclient/orchestrator/mapping.go @@ -0,0 +1,53 @@ +package orchestrator + +import "cmp" + +// This is a collection of generic functions that can be used to manipulate maps BY pointer. +// This is useful for observers/signers because we want to operate with the same map and not a copy. +// It simplifies code semantics by hiding complexity of accessing map by a pointer. + +// mapHas checks if the map contains the given key. +func mapHas[K cmp.Ordered, V any](m *map[K]V, key K) bool { + _, ok := (*m)[key] + return ok +} + +// mapSet sets the value for the given key in the map +// and (optionally) runs a callback after setting the value. +func mapSet[K cmp.Ordered, V any](m *map[K]V, key K, value V, afterSet func(K, V)) { + (*m)[key] = value + + if afterSet != nil { + afterSet(key, value) + } +} + +// mapUnset removes the value for the given key from the map (if exists) +// and optionally runs a callback before removing the value. +func mapUnset[K cmp.Ordered, V any](m *map[K]V, key K, beforeUnset func(K, V)) bool { + if !mapHas(m, key) { + return false + } + + if beforeUnset != nil { + beforeUnset(key, (*m)[key]) + } + + delete(*m, key) + + return true +} + +// mapDeleteMissingKeys removes elements from the map IF they are not in the presentKeys. +func mapDeleteMissingKeys[K cmp.Ordered, V any](m *map[K]V, presentKeys []K, beforeUnset func(K, V)) { + set := make(map[K]struct{}, len(presentKeys)) + for _, id := range presentKeys { + set[id] = struct{}{} + } + + for key := range *m { + if _, found := set[key]; !found { + mapUnset(m, key, beforeUnset) + } + } +} diff --git a/zetaclient/orchestrator/orchestrator.go b/zetaclient/orchestrator/orchestrator.go index 951c4146f6..25057b9f1f 100644 --- a/zetaclient/orchestrator/orchestrator.go +++ b/zetaclient/orchestrator/orchestrator.go @@ -5,10 +5,12 @@ import ( "context" "fmt" "math" + "sync" "time" sdkmath "cosmossdk.io/math" ethcommon "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/zeta-chain/zetacore/pkg/bg" @@ -16,7 +18,9 @@ import ( zetamath "github.com/zeta-chain/zetacore/pkg/math" "github.com/zeta-chain/zetacore/x/crosschain/types" observertypes "github.com/zeta-chain/zetacore/x/observer/types" + "github.com/zeta-chain/zetacore/zetaclient/chains/base" btcobserver "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin/observer" + "github.com/zeta-chain/zetacore/zetaclient/chains/evm" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" zctx "github.com/zeta-chain/zetacore/zetaclient/context" "github.com/zeta-chain/zetacore/zetaclient/metrics" @@ -35,12 +39,7 @@ const ( loggerSamplingRate = 10 ) -// Log is a struct that contains the logger -// TODO(revamp): rename to logger -type Log struct { - Std zerolog.Logger - Sampled zerolog.Logger -} +var defaultLogSampler = &zerolog.BasicSampler{N: loggerSamplingRate} // Orchestrator wraps the zetacore client, chain observers and signers. This is the high level object used for CCTX scheduling type Orchestrator struct { @@ -59,59 +58,80 @@ type Orchestrator struct { // last operator balance lastOperatorBalance sdkmath.Int + // observer & signer props + tss interfaces.TSSSigner + dbDirectory string + baseLogger base.Logger + // misc - logger Log - stop chan struct{} + logger multiLogger ts *metrics.TelemetryServer + stop chan struct{} + mu sync.RWMutex } -// NewOrchestrator creates a new orchestrator -func NewOrchestrator( +type multiLogger struct { + zerolog.Logger + Sampled zerolog.Logger +} + +// New creates a new Orchestrator +func New( ctx context.Context, - zetacoreClient interfaces.ZetacoreClient, + client interfaces.ZetacoreClient, signerMap map[int64]interfaces.ChainSigner, observerMap map[int64]interfaces.ChainObserver, - logger zerolog.Logger, + tss interfaces.TSSSigner, + dbDirectory string, + logger base.Logger, ts *metrics.TelemetryServer, -) *Orchestrator { - oc := Orchestrator{ - ts: ts, - stop: make(chan struct{}), +) (*Orchestrator, error) { + if signerMap == nil || observerMap == nil { + return nil, errors.New("signerMap or observerMap is nil") } - // create loggers - oc.logger = Log{ - Std: logger.With().Str("module", "Orchestrator").Logger(), + log := multiLogger{ + Logger: logger.Std.With().Str("module", "orchestrator").Logger(), + Sampled: logger.Std.With().Str("module", "orchestrator").Logger().Sample(defaultLogSampler), } - oc.logger.Sampled = oc.logger.Std.Sample(&zerolog.BasicSampler{N: loggerSamplingRate}) - - // set zetacore client, signers and chain observers - oc.zetacoreClient = zetacoreClient - oc.signerMap = signerMap - oc.observerMap = observerMap - - // create outbound processor - oc.outboundProc = outboundprocessor.NewProcessor(logger) - balance, err := zetacoreClient.GetZetaHotKeyBalance(ctx) + balance, err := client.GetZetaHotKeyBalance(ctx) if err != nil { - oc.logger.Std.Error().Err(err).Msg("error getting last balance of the hot key") + return nil, errors.Wrap(err, "unable to get last balance of the hot key") } - oc.lastOperatorBalance = balance - return &oc + return &Orchestrator{ + zetacoreClient: client, + + signerMap: signerMap, + observerMap: observerMap, + + outboundProc: outboundprocessor.NewProcessor(logger.Std), + lastOperatorBalance: balance, + + // observer & signer props + tss: tss, + dbDirectory: dbDirectory, + baseLogger: logger, + + logger: log, + ts: ts, + stop: make(chan struct{}), + }, nil } -// MonitorCore starts the orchestrator for CCTXs -func (oc *Orchestrator) MonitorCore(ctx context.Context) error { +// Start starts the orchestrator for CCTXs. +func (oc *Orchestrator) Start(ctx context.Context) error { signerAddress, err := oc.zetacoreClient.GetKeys().GetAddress() if err != nil { - return fmt.Errorf("failed to get signer address: %w", err) + return errors.Wrap(err, "unable to get signer address") } - oc.logger.Std.Info().Msgf("Starting orchestrator for signer: %s", signerAddress) + + oc.logger.Info().Str("signer", signerAddress.String()).Msg("Starting orchestrator") // start cctx scheduler - bg.Work(ctx, oc.StartCctxScheduler, bg.WithName("StartCctxScheduler"), bg.WithLogger(oc.logger.Std)) + bg.Work(ctx, oc.runScheduler, bg.WithName("runScheduler"), bg.WithLogger(oc.logger.Logger)) + bg.Work(ctx, oc.runObserverSignerSync, bg.WithName("runObserverSignerSync"), bg.WithLogger(oc.logger.Logger)) shutdownOrchestrator := func() { // now stop orchestrator and all observers @@ -126,67 +146,97 @@ func (oc *Orchestrator) MonitorCore(ctx context.Context) error { return nil } -// GetUpdatedSigner returns signer with updated chain parameters -func (oc *Orchestrator) GetUpdatedSigner( - appContext *zctx.AppContext, - chainID int64, -) (interfaces.ChainSigner, error) { - signer, found := oc.signerMap[chainID] +// returns signer with updated chain parameters. +func (oc *Orchestrator) resolveSigner(app *zctx.AppContext, chainID int64) (interfaces.ChainSigner, error) { + signer, err := oc.getSigner(chainID) + if err != nil { + return nil, err + } + + // noop for non-EVM chains + if !chains.IsEVMChain(chainID, app.GetAdditionalChains()) { + return signer, nil + } + + evmParams, found := app.GetEVMChainParams(chainID) if !found { - return nil, fmt.Errorf("signer not found for chainID %d", chainID) + return signer, nil } - // update EVM signer parameters only. BTC signer doesn't use chain parameters for now. - if chains.IsEVMChain(chainID, appContext.GetAdditionalChains()) { - evmParams, found := appContext.GetEVMChainParams(chainID) - if found { - // update zeta connector and ERC20 custody addresses - zetaConnectorAddress := ethcommon.HexToAddress(evmParams.GetConnectorContractAddress()) - erc20CustodyAddress := ethcommon.HexToAddress(evmParams.GetErc20CustodyContractAddress()) - if zetaConnectorAddress != signer.GetZetaConnectorAddress() { - signer.SetZetaConnectorAddress(zetaConnectorAddress) - oc.logger.Std.Info().Msgf( - "updated zeta connector address for chainID %d, new address: %s", chainID, zetaConnectorAddress) - } - if erc20CustodyAddress != signer.GetERC20CustodyAddress() { - signer.SetERC20CustodyAddress(erc20CustodyAddress) - oc.logger.Std.Info().Msgf( - "updated ERC20 custody address for chainID %d, new address: %s", chainID, erc20CustodyAddress) - } - } + + // update zeta connector and ERC20 custody addresses + zetaConnectorAddress := ethcommon.HexToAddress(evmParams.GetConnectorContractAddress()) + if zetaConnectorAddress != signer.GetZetaConnectorAddress() { + signer.SetZetaConnectorAddress(zetaConnectorAddress) + oc.logger.Info(). + Str("signer.connector_address", zetaConnectorAddress.String()). + Msgf("updated zeta connector address for chain %d", chainID) + } + + erc20CustodyAddress := ethcommon.HexToAddress(evmParams.GetErc20CustodyContractAddress()) + if erc20CustodyAddress != signer.GetERC20CustodyAddress() { + signer.SetERC20CustodyAddress(erc20CustodyAddress) + oc.logger.Info(). + Str("signer.erc20_custody", erc20CustodyAddress.String()). + Msgf("updated zeta connector address for chain %d", chainID) } + return signer, nil } -// GetUpdatedChainObserver returns chain observer with updated chain parameters -func (oc *Orchestrator) GetUpdatedChainObserver( - appContext *zctx.AppContext, - chainID int64, -) (interfaces.ChainObserver, error) { - observer, found := oc.observerMap[chainID] +func (oc *Orchestrator) getSigner(chainID int64) (interfaces.ChainSigner, error) { + oc.mu.RLock() + defer oc.mu.RUnlock() + + s, found := oc.signerMap[chainID] if !found { - return nil, fmt.Errorf("chain observer not found for chainID %d", chainID) + return nil, fmt.Errorf("signer not found for chainID %d", chainID) + } + + return s, nil +} + +// returns chain observer with updated chain parameters +func (oc *Orchestrator) resolveObserver(app *zctx.AppContext, chainID int64) (interfaces.ChainObserver, error) { + observer, err := oc.getObserver(chainID) + if err != nil { + return nil, err } + // update chain observer chain parameters curParams := observer.GetChainParams() - if chains.IsEVMChain(chainID, appContext.GetAdditionalChains()) { - evmParams, found := appContext.GetEVMChainParams(chainID) + if chains.IsEVMChain(chainID, app.GetAdditionalChains()) { + evmParams, found := app.GetEVMChainParams(chainID) if found && !observertypes.ChainParamsEqual(curParams, *evmParams) { observer.SetChainParams(*evmParams) - oc.logger.Std.Info().Msgf( - "updated chain params for chainID %d, new params: %v", chainID, *evmParams) + oc.logger.Info(). + Interface("observer.chain_params", *evmParams). + Msgf("updated chain params for EVM chainID %d", chainID) } - } else if chains.IsBitcoinChain(chainID, appContext.GetAdditionalChains()) { - _, btcParams, found := appContext.GetBTCChainParams() - + } else if chains.IsBitcoinChain(chainID, app.GetAdditionalChains()) { + _, btcParams, found := app.GetBTCChainParams() if found && !observertypes.ChainParamsEqual(curParams, *btcParams) { observer.SetChainParams(*btcParams) - oc.logger.Std.Info().Msgf( - "updated chain params for Bitcoin, new params: %v", *btcParams) + oc.logger.Info(). + Interface("observer.chain_params", *btcParams). + Msgf("updated chain params for UTXO chainID %d", btcParams.ChainId) } } + return observer, nil } +func (oc *Orchestrator) getObserver(chainID int64) (interfaces.ChainObserver, error) { + oc.mu.RLock() + defer oc.mu.RUnlock() + + ob, found := oc.observerMap[chainID] + if !found { + return nil, fmt.Errorf("observer not found for chainID %d", chainID) + } + + return ob, nil +} + // GetPendingCctxsWithinRateLimit get pending cctxs across foreign chains within rate limit func (oc *Orchestrator) GetPendingCctxsWithinRateLimit( ctx context.Context, @@ -238,9 +288,9 @@ func (oc *Orchestrator) GetPendingCctxsWithinRateLimit( return output.CctxsMap, nil } -// StartCctxScheduler schedules keysigns for cctxs on each ZetaChain block (the ticker) +// schedules keysigns for cctxs on each ZetaChain block (the ticker) // TODO(revamp): make this function simpler -func (oc *Orchestrator) StartCctxScheduler(ctx context.Context) error { +func (oc *Orchestrator) runScheduler(ctx context.Context) error { app, err := zctx.FromContext(ctx) if err != nil { return err @@ -251,17 +301,17 @@ func (oc *Orchestrator) StartCctxScheduler(ctx context.Context) error { for { select { case <-oc.stop: - oc.logger.Std.Warn().Msg("StartCctxScheduler: stopped") + oc.logger.Warn().Msg("runScheduler: stopped") return nil case <-observeTicker.C: { bn, err := oc.zetacoreClient.GetBlockHeight(ctx) if err != nil { - oc.logger.Std.Error().Err(err).Msg("StartCctxScheduler: GetBlockHeight fail") + oc.logger.Error().Err(err).Msg("StartCctxScheduler: GetBlockHeight fail") continue } if bn < 0 { - oc.logger.Std.Error().Msg("StartCctxScheduler: GetBlockHeight returned negative height") + oc.logger.Error().Msg("runScheduler: GetBlockHeight returned negative height") continue } if lastBlockNum == 0 { @@ -270,12 +320,12 @@ func (oc *Orchestrator) StartCctxScheduler(ctx context.Context) error { if bn > lastBlockNum { // we have a new block bn = lastBlockNum + 1 if bn%10 == 0 { - oc.logger.Std.Debug().Msgf("StartCctxScheduler: zetacore heart beat: %d", bn) + oc.logger.Debug().Msgf("runScheduler: zetacore heart beat: %d", bn) } balance, err := oc.zetacoreClient.GetZetaHotKeyBalance(ctx) if err != nil { - oc.logger.Std.Error().Err(err).Msgf("couldn't get operator balance") + oc.logger.Error().Err(err).Msgf("couldn't get operator balance") } else { diff := oc.lastOperatorBalance.Sub(balance) if diff.GT(sdkmath.NewInt(0)) && diff.LT(sdkmath.NewInt(math.MaxInt64)) { @@ -293,7 +343,7 @@ func (oc *Orchestrator) StartCctxScheduler(ctx context.Context) error { // query pending cctxs across all external chains within rate limit cctxMap, err := oc.GetPendingCctxsWithinRateLimit(ctx, externalChains) if err != nil { - oc.logger.Std.Error().Err(err).Msgf("StartCctxScheduler: GetPendingCctxsWithinRatelimit failed") + oc.logger.Error().Err(err).Msgf("runScheduler: GetPendingCctxsWithinRatelimit failed") } // schedule keysign for pending cctxs on each chain @@ -306,18 +356,16 @@ func (oc *Orchestrator) StartCctxScheduler(ctx context.Context) error { } // update chain parameters for signer and chain observer - signer, err := oc.GetUpdatedSigner(app, c.ChainId) + signer, err := oc.resolveSigner(app, c.ChainId) if err != nil { - oc.logger.Std.Error(). - Err(err). - Msgf("StartCctxScheduler: GetUpdatedSigner failed for chain %d", c.ChainId) + oc.logger.Error().Err(err). + Msgf("runScheduler: unable to resolve signer for chain %d", c.ChainId) continue } - ob, err := oc.GetUpdatedChainObserver(app, c.ChainId) + ob, err := oc.resolveObserver(app, c.ChainId) if err != nil { - oc.logger.Std.Error(). - Err(err). - Msgf("StartCctxScheduler: GetUpdatedChainObserver failed for chain %d", c.ChainId) + oc.logger.Error().Err(err). + Msgf("runScheduler: resolveObserver failed for chain %d", c.ChainId) continue } if !app.IsOutboundObservationEnabled(ob.GetChainParams()) { @@ -331,7 +379,7 @@ func (oc *Orchestrator) StartCctxScheduler(ctx context.Context) error { } else if chains.IsBitcoinChain(c.ChainId, app.GetAdditionalChains()) { oc.ScheduleCctxBTC(ctx, zetaHeight, c.ChainId, cctxList, ob, signer) } else { - oc.logger.Std.Error().Msgf("StartCctxScheduler: unsupported chain %d", c.ChainId) + oc.logger.Error().Msgf("StartCctxScheduler: unsupported chain %d", c.ChainId) continue } } @@ -356,7 +404,7 @@ func (oc *Orchestrator) ScheduleCctxEVM( ) { res, err := oc.zetacoreClient.GetAllOutboundTrackerByChain(ctx, chainID, interfaces.Ascending) if err != nil { - oc.logger.Std.Warn().Err(err).Msgf("ScheduleCctxEVM: GetAllOutboundTrackerByChain failed for chain %d", chainID) + oc.logger.Warn().Err(err).Msgf("ScheduleCctxEVM: GetAllOutboundTrackerByChain failed for chain %d", chainID) return } trackerMap := make(map[uint64]bool) @@ -375,26 +423,26 @@ func (oc *Orchestrator) ScheduleCctxEVM( outboundID := outboundprocessor.ToOutboundID(cctx.Index, params.ReceiverChainId, nonce) if params.ReceiverChainId != chainID { - oc.logger.Std.Error(). + oc.logger.Error(). Msgf("ScheduleCctxEVM: outbound %s chainid mismatch: want %d, got %d", outboundID, chainID, params.ReceiverChainId) continue } if params.TssNonce > cctxList[0].GetCurrentOutboundParam().TssNonce+outboundScheduleLookback { - oc.logger.Std.Error().Msgf("ScheduleCctxEVM: nonce too high: signing %d, earliest pending %d, chain %d", + oc.logger.Error().Msgf("ScheduleCctxEVM: nonce too high: signing %d, earliest pending %d, chain %d", params.TssNonce, cctxList[0].GetCurrentOutboundParam().TssNonce, chainID) break } // try confirming the outbound - included, _, err := observer.IsOutboundProcessed(ctx, cctx, oc.logger.Std) + included, _, err := observer.IsOutboundProcessed(ctx, cctx, oc.logger.Logger) if err != nil { - oc.logger.Std.Error(). + oc.logger.Error(). Err(err). Msgf("ScheduleCctxEVM: IsOutboundProcessed faild for chain %d nonce %d", chainID, nonce) continue } if included { - oc.logger.Std.Info(). + oc.logger.Info(). Msgf("ScheduleCctxEVM: outbound %s already included; do not schedule keysign", outboundID) continue } @@ -424,7 +472,7 @@ func (oc *Orchestrator) ScheduleCctxEVM( if nonce%outboundScheduleInterval == zetaHeight%outboundScheduleInterval && !oc.outboundProc.IsOutboundActive(outboundID) { oc.outboundProc.StartTryProcess(outboundID) - oc.logger.Std.Debug(). + oc.logger.Debug(). Msgf("ScheduleCctxEVM: sign outbound %s with value %d\n", outboundID, cctx.GetCurrentOutboundParam().Amount) go signer.TryProcessOutbound( ctx, @@ -458,7 +506,7 @@ func (oc *Orchestrator) ScheduleCctxBTC( ) { btcObserver, ok := observer.(*btcobserver.Observer) if !ok { // should never happen - oc.logger.Std.Error().Msgf("ScheduleCctxBTC: chain observer is not a bitcoin observer") + oc.logger.Error().Msgf("ScheduleCctxBTC: chain observer is not a bitcoin observer") return } // #nosec G115 positive @@ -472,20 +520,20 @@ func (oc *Orchestrator) ScheduleCctxBTC( outboundID := outboundprocessor.ToOutboundID(cctx.Index, params.ReceiverChainId, nonce) if params.ReceiverChainId != chainID { - oc.logger.Std.Error(). + oc.logger.Error(). Msgf("ScheduleCctxBTC: outbound %s chainid mismatch: want %d, got %d", outboundID, chainID, params.ReceiverChainId) continue } // try confirming the outbound - included, confirmed, err := btcObserver.IsOutboundProcessed(ctx, cctx, oc.logger.Std) + included, confirmed, err := btcObserver.IsOutboundProcessed(ctx, cctx, oc.logger.Logger) if err != nil { - oc.logger.Std.Error(). + oc.logger.Error(). Err(err). Msgf("ScheduleCctxBTC: IsOutboundProcessed faild for chain %d nonce %d", chainID, nonce) continue } if included || confirmed { - oc.logger.Std.Info(). + oc.logger.Info(). Msgf("ScheduleCctxBTC: outbound %s already included; do not schedule keysign", outboundID) continue } @@ -498,14 +546,14 @@ func (oc *Orchestrator) ScheduleCctxBTC( if int64( idx, ) >= lookahead { // 2 bitcoin confirmations span is 20 minutes on average. We look ahead up to 100 pending cctx to target TPM of 5. - oc.logger.Std.Warn(). + oc.logger.Warn(). Msgf("ScheduleCctxBTC: lookahead reached, signing %d, earliest pending %d", nonce, cctxList[0].GetCurrentOutboundParam().TssNonce) break } // try confirming the outbound or scheduling a keysign if nonce%interval == zetaHeight%interval && !oc.outboundProc.IsOutboundActive(outboundID) { oc.outboundProc.StartTryProcess(outboundID) - oc.logger.Std.Debug().Msgf("ScheduleCctxBTC: sign outbound %s with value %d\n", outboundID, params.Amount) + oc.logger.Debug().Msgf("ScheduleCctxBTC: sign outbound %s with value %d\n", outboundID, params.Amount) go signer.TryProcessOutbound( ctx, cctx, @@ -518,3 +566,61 @@ func (oc *Orchestrator) ScheduleCctxBTC( } } } + +// runObserverSignerSync runs a blocking ticker that observes chain changes from zetacore +// and optionally (de)provisions respective observers and signers. +func (oc *Orchestrator) runObserverSignerSync(ctx context.Context) error { + // check every other zeta block + const cadence = 2 * evm.ZetaBlockTime + + ticker := time.NewTicker(cadence) + defer ticker.Stop() + + for { + select { + case <-oc.stop: + oc.logger.Warn().Msg("runObserverSignerSync: stopped") + return nil + case <-ticker.C: + if err := oc.syncObserverSigner(ctx); err != nil { + oc.logger.Error().Err(err).Msg("runObserverSignerSync: syncObserverSigner failed") + } + } + } +} + +// syncs and provisions observers & signers. +// Note that zctx.AppContext Update is a responsibility of another component +// See zetacore.Client{}.UpdateZetacoreContextWorker +func (oc *Orchestrator) syncObserverSigner(ctx context.Context) error { + oc.mu.Lock() + defer oc.mu.Unlock() + + client := oc.zetacoreClient + + added, removed, err := syncObserverMap(ctx, client, oc.tss, oc.dbDirectory, oc.baseLogger, oc.ts, &oc.observerMap) + if err != nil { + return errors.Wrap(err, "syncObserverMap failed") + } + + if added+removed > 0 { + oc.logger.Info(). + Int("observer.added", added). + Int("observer.removed", removed). + Msg("synced observers") + } + + added, removed, err = syncSignerMap(ctx, oc.tss, oc.baseLogger, oc.ts, &oc.signerMap) + if err != nil { + return errors.Wrap(err, "syncSignerMap failed") + } + + if added+removed > 0 { + oc.logger.Info(). + Int("signers.added", added). + Int("signers.removed", removed). + Msg("synced signers") + } + + return nil +} diff --git a/zetaclient/orchestrator/orchestrator_test.go b/zetaclient/orchestrator/orchestrator_test.go index 06e7aaf625..4115dee69a 100644 --- a/zetaclient/orchestrator/orchestrator_test.go +++ b/zetaclient/orchestrator/orchestrator_test.go @@ -112,14 +112,14 @@ func Test_GetUpdatedSigner(t *testing.T) { orchestrator := MockOrchestrator(t, nil, evmChain, btcChain, evmChainParams, btcChainParams) context := CreateAppContext(evmChain, btcChain, evmChainParamsNew, btcChainParams) // BSC signer should not be found - _, err := orchestrator.GetUpdatedSigner(context, chains.BscMainnet.ChainId) + _, err := orchestrator.resolveSigner(context, chains.BscMainnet.ChainId) require.ErrorContains(t, err, "signer not found") }) t.Run("should be able to update connector and erc20 custody address", func(t *testing.T) { orchestrator := MockOrchestrator(t, nil, evmChain, btcChain, evmChainParams, btcChainParams) context := CreateAppContext(evmChain, btcChain, evmChainParamsNew, btcChainParams) // update signer with new connector and erc20 custody address - signer, err := orchestrator.GetUpdatedSigner(context, evmChain.ChainId) + signer, err := orchestrator.resolveSigner(context, evmChain.ChainId) require.NoError(t, err) require.Equal(t, testutils.OtherAddress1, signer.GetZetaConnectorAddress().Hex()) require.Equal(t, testutils.OtherAddress2, signer.GetERC20CustodyAddress().Hex()) @@ -177,14 +177,14 @@ func Test_GetUpdatedChainObserver(t *testing.T) { orchestrator := MockOrchestrator(t, nil, evmChain, btcChain, evmChainParams, btcChainParams) coreContext := CreateAppContext(evmChain, btcChain, evmChainParamsNew, btcChainParams) // BSC chain observer should not be found - _, err := orchestrator.GetUpdatedChainObserver(coreContext, chains.BscMainnet.ChainId) - require.ErrorContains(t, err, "chain observer not found") + _, err := orchestrator.resolveObserver(coreContext, chains.BscMainnet.ChainId) + require.ErrorContains(t, err, "observer not found") }) t.Run("chain params in evm chain observer should be updated successfully", func(t *testing.T) { orchestrator := MockOrchestrator(t, nil, evmChain, btcChain, evmChainParams, btcChainParams) coreContext := CreateAppContext(evmChain, btcChain, evmChainParamsNew, btcChainParams) // update evm chain observer with new chain params - chainOb, err := orchestrator.GetUpdatedChainObserver(coreContext, evmChain.ChainId) + chainOb, err := orchestrator.resolveObserver(coreContext, evmChain.ChainId) require.NoError(t, err) require.NotNil(t, chainOb) require.True(t, observertypes.ChainParamsEqual(*evmChainParamsNew, chainOb.GetChainParams())) @@ -193,14 +193,14 @@ func Test_GetUpdatedChainObserver(t *testing.T) { orchestrator := MockOrchestrator(t, nil, evmChain, btcChain, evmChainParams, btcChainParams) coreContext := CreateAppContext(btcChain, btcChain, evmChainParams, btcChainParamsNew) // BTC testnet chain observer should not be found - _, err := orchestrator.GetUpdatedChainObserver(coreContext, chains.BitcoinTestnet.ChainId) - require.ErrorContains(t, err, "chain observer not found") + _, err := orchestrator.resolveObserver(coreContext, chains.BitcoinTestnet.ChainId) + require.ErrorContains(t, err, "observer not found") }) t.Run("chain params in btc chain observer should be updated successfully", func(t *testing.T) { orchestrator := MockOrchestrator(t, nil, evmChain, btcChain, evmChainParams, btcChainParams) coreContext := CreateAppContext(btcChain, btcChain, evmChainParams, btcChainParamsNew) // update btc chain observer with new chain params - chainOb, err := orchestrator.GetUpdatedChainObserver(coreContext, btcChain.ChainId) + chainOb, err := orchestrator.resolveObserver(coreContext, btcChain.ChainId) require.NoError(t, err) require.NotNil(t, chainOb) require.True(t, observertypes.ChainParamsEqual(*btcChainParamsNew, chainOb.GetChainParams())) diff --git a/zetaclient/testutils/constant.go b/zetaclient/testutils/constant.go index 3d4f6e2a03..ad8302577d 100644 --- a/zetaclient/testutils/constant.go +++ b/zetaclient/testutils/constant.go @@ -27,9 +27,6 @@ const ( EventZetaReverted = "ZetaReverted" EventERC20Deposit = "Deposited" EventERC20Withdraw = "Withdrawn" - - // SQLiteMemory is a SQLite in-memory database connection string. - SQLiteMemory = "file::memory:?cache=shared" ) // ConnectorAddresses contains constants ERC20 connector addresses for testing diff --git a/zetaclient/testutils/testrpc/rpc.go b/zetaclient/testutils/testrpc/rpc.go new file mode 100644 index 0000000000..f444631813 --- /dev/null +++ b/zetaclient/testutils/testrpc/rpc.go @@ -0,0 +1,87 @@ +package testrpc + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +// Server represents JSON RPC mock with a "real" HTTP server allocated (httptest) +type Server struct { + t *testing.T + handlers map[string]func(params []any) (any, error) + name string +} + +// New constructs Server. +func New(t *testing.T, name string) (*Server, string) { + var ( + handlers = make(map[string]func(params []any) (any, error)) + rpc = &Server{t, handlers, name} + testWeb = httptest.NewServer(http.HandlerFunc(rpc.httpHandler)) + ) + + t.Cleanup(testWeb.Close) + + return rpc, testWeb.URL +} + +// On registers a handler for a given method. +func (s *Server) On(method string, call func(params []any) (any, error)) { + s.handlers[method] = call +} + +// example: {"jsonrpc":"1.0","method":"ping","params":[],"id":1} +type rpcRequest struct { + Method string `json:"method"` + Params []any `json:"params"` +} + +// example: {"result":0,"error":null,"id":"curltest"} +type rpcResponse struct { + Result any `json:"result"` + Error error `json:"error"` +} + +// handler is a simple HTTP handler that returns 200 OK. +// Later we can add any logic here. +func (s *Server) httpHandler(w http.ResponseWriter, r *http.Request) { + // Make sure method matches + require.Equal(s.t, http.MethodPost, r.Method) + + var req rpcRequest + + // Decode request + raw, err := io.ReadAll(r.Body) + require.NoError(s.t, err) + require.NoError(s.t, json.Unmarshal(raw, &req), "unable to unmarshal request") + + // Process request + res := s.rpcHandler(req) + + // Encode response + response, err := json.Marshal(res) + require.NoError(s.t, err, "unable to marshal response") + + w.WriteHeader(http.StatusOK) + _, err = w.Write(response) + require.NoError(s.t, err, "unable to write response") + + s.t.Logf("%s RPC: incoming request: %+v; response: %+v", s.name, req, res) +} + +func (s *Server) rpcHandler(req rpcRequest) rpcResponse { + call, ok := s.handlers[req.Method] + if !ok { + return rpcResponse{Error: errors.New("method not found")} + } + + res, err := call(req.Params) + + return rpcResponse{Result: res, Error: err} +} diff --git a/zetaclient/testutils/testrpc/rpc_btc.go b/zetaclient/testutils/testrpc/rpc_btc.go new file mode 100644 index 0000000000..ec43242944 --- /dev/null +++ b/zetaclient/testutils/testrpc/rpc_btc.go @@ -0,0 +1,52 @@ +package testrpc + +import ( + "fmt" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/zeta-chain/zetacore/zetaclient/config" +) + +// BtcServer represents httptest for Bitcoin RPC. +type BtcServer struct { + *Server +} + +// NewBtcServer creates new BtcServer. +func NewBtcServer(t *testing.T) (*BtcServer, config.BTCConfig) { + rpc, rpcURL := New(t, "bitcoin") + + host, err := formatBitcoinRPCHost(rpcURL) + require.NoError(t, err) + + cfg := config.BTCConfig{ + RPCUsername: "btc-user", + RPCPassword: "btc-password", + RPCHost: host, + RPCParams: "", + } + + rpc.On("ping", func(_ []any) (any, error) { + return nil, nil + }) + + return &BtcServer{rpc}, cfg +} + +func (s *BtcServer) SetBlockCount(count int) { + s.On("getblockcount", func(_ []any) (any, error) { + return count, nil + }) +} + +func formatBitcoinRPCHost(serverURL string) (string, error) { + u, err := url.Parse(serverURL) + if err != nil { + return "", err + } + + return fmt.Sprintf("%s:%s", u.Hostname(), u.Port()), nil +} diff --git a/zetaclient/testutils/testrpc/rpc_evm.go b/zetaclient/testutils/testrpc/rpc_evm.go new file mode 100644 index 0000000000..a03e748245 --- /dev/null +++ b/zetaclient/testutils/testrpc/rpc_evm.go @@ -0,0 +1,29 @@ +package testrpc + +import ( + "fmt" + "testing" +) + +// EVMServer represents httptest for EVM RPC. +type EVMServer struct { + *Server + Endpoint string +} + +// NewEVMServer creates a new EVMServer. +func NewEVMServer(t *testing.T) *EVMServer { + rpc, endpoint := New(t, "EVM") + + return &EVMServer{Server: rpc, Endpoint: endpoint} +} + +func (s *EVMServer) SetBlockNumber(n int) { + s.On("eth_blockNumber", func(_ []any) (any, error) { + return hex(n), nil + }) +} + +func hex(v any) string { + return fmt.Sprintf("0x%x", v) +} diff --git a/zetaclient/testutils/testrpc/rpc_solana.go b/zetaclient/testutils/testrpc/rpc_solana.go new file mode 100644 index 0000000000..7487f7fb6a --- /dev/null +++ b/zetaclient/testutils/testrpc/rpc_solana.go @@ -0,0 +1,21 @@ +package testrpc + +import ( + "testing" + + "github.com/zeta-chain/zetacore/zetaclient/config" +) + +// SolanaServer represents httptest for SOL RPC. +type SolanaServer struct { + *Server +} + +// NewSolanaServer creates a new SolanaServer. +func NewSolanaServer(t *testing.T) (*SolanaServer, config.SolanaConfig) { + rpc, endpoint := New(t, "Solana") + + cfg := config.SolanaConfig{Endpoint: endpoint} + + return &SolanaServer{Server: rpc}, cfg +}