From 87b2f41c62b79d3cdb748f0ac9ddb04aebd745b6 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Thu, 27 Jun 2024 15:06:05 +0200 Subject: [PATCH] Fix observer.OpenDB() in-mem sqlite creation (it was creating redundant directory) --- zetaclient/chains/base/observer.go | 42 ++++++++++++------- zetaclient/chains/base/observer_test.go | 2 +- .../chains/bitcoin/observer/observer_test.go | 2 +- .../chains/evm/observer/observer_test.go | 2 +- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/zetaclient/chains/base/observer.go b/zetaclient/chains/base/observer.go index 2999d8d21a..b6dfdef06d 100644 --- a/zetaclient/chains/base/observer.go +++ b/zetaclient/chains/base/observer.go @@ -290,27 +290,26 @@ func (ob *Observer) StopChannel() chan struct{} { // OpenDB open sql database in the given path. func (ob *Observer) OpenDB(dbPath string, dbName string) error { - // create db path if not exist - if _, err := os.Stat(dbPath); os.IsNotExist(err) { - err := os.MkdirAll(dbPath, os.ModePerm) - if err != nil { - return errors.Wrapf(err, "error creating db path: %s", dbPath) + var dial gorm.Dialector + + // SQLite in-mem db + if strings.Contains(dbPath, ":memory:") { + dial = sqlite.Open(dbPath) + } else { + if err := ensureDirectory(dbPath); err != nil { + return errors.Wrapf(err, "unable to ensure dbPath %q", dbPath) } - } - // use custom dbName or chain name if not provided - if dbName == "" { - dbName = ob.chain.ChainName.String() - } - path := fmt.Sprintf("%s/%s", dbPath, dbName) + // use custom dbName or chain name if not provided + if dbName == "" { + dbName = ob.chain.ChainName.String() + } - // use memory db if specified - if strings.Contains(dbPath, ":memory:") { - path = dbPath + dial = sqlite.Open(fmt.Sprintf("%s/%s", dbPath, dbName)) } // open db - db, err := gorm.Open(sqlite.Open(path), &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) + db, err := gorm.Open(dial, &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}) if err != nil { return errors.Wrap(err, "error opening db") } @@ -320,11 +319,24 @@ func (ob *Observer) OpenDB(dbPath string, dbName string) error { if err != nil { return errors.Wrap(err, "error migrating db") } + ob.db = db return nil } +func ensureDirectory(path string) error { + _, err := os.Stat(path) + switch { + case os.IsNotExist(err): + return os.MkdirAll(path, os.ModePerm) + case err != nil: + return err + default: + return nil + } +} + // CloseDB close the database. func (ob *Observer) CloseDB() error { dbInst, err := ob.db.DB() diff --git a/zetaclient/chains/base/observer_test.go b/zetaclient/chains/base/observer_test.go index a04a48fcc3..5a382b1554 100644 --- a/zetaclient/chains/base/observer_test.go +++ b/zetaclient/chains/base/observer_test.go @@ -276,7 +276,7 @@ func TestOpenCloseDB(t *testing.T) { }) t.Run("should return error on invalid db path", func(t *testing.T) { err := ob.OpenDB("/invalid/123db", "") - require.ErrorContains(t, err, "error creating db path") + require.ErrorContains(t, err, "unable to ensure dbPath") }) } diff --git a/zetaclient/chains/bitcoin/observer/observer_test.go b/zetaclient/chains/bitcoin/observer/observer_test.go index c79209e3fc..05b98c0ef7 100644 --- a/zetaclient/chains/bitcoin/observer/observer_test.go +++ b/zetaclient/chains/bitcoin/observer/observer_test.go @@ -161,7 +161,7 @@ func Test_NewObserver(t *testing.T) { logger: base.Logger{}, ts: nil, fail: true, - message: "error creating db path", + message: "unable to ensure dbPath", }, } diff --git a/zetaclient/chains/evm/observer/observer_test.go b/zetaclient/chains/evm/observer/observer_test.go index f149d1bae2..f0dec5939d 100644 --- a/zetaclient/chains/evm/observer/observer_test.go +++ b/zetaclient/chains/evm/observer/observer_test.go @@ -153,7 +153,7 @@ func Test_NewObserver(t *testing.T) { logger: base.Logger{}, ts: nil, fail: true, - message: "error creating db path", + message: "unable to ensure dbPath", }, { name: "should fail if RPC call fails",