-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
171 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// Package db represents API for database operations. | ||
package db | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"strings" | ||
|
||
"github.com/pkg/errors" | ||
"gorm.io/driver/sqlite" | ||
"gorm.io/gorm" | ||
"gorm.io/gorm/logger" | ||
|
||
"github.com/zeta-chain/zetacore/zetaclient/types" | ||
) | ||
|
||
// SqliteInMemory is a special string to use in-memory database. | ||
// @see https://www.sqlite.org/inmemorydb.html | ||
const SqliteInMemory = ":memory:" | ||
|
||
// read/write/execute for user | ||
// read/write for group | ||
const dirCreationMode = 0o750 | ||
|
||
var ( | ||
defaultGormConfig = &gorm.Config{ | ||
Logger: logger.Default.LogMode(logger.Silent), | ||
} | ||
migrationEntities = []any{ | ||
&types.LastBlockSQLType{}, | ||
&types.TransactionSQLType{}, | ||
&types.ReceiptSQLType{}, | ||
} | ||
) | ||
|
||
// DB database. | ||
type DB struct { | ||
db *gorm.DB | ||
} | ||
|
||
// NewFromSqlite creates a new instance of DB based on SQLite database. | ||
func NewFromSqlite(directory, dbName string, migrate bool) (*DB, error) { | ||
path, err := ensurePath(directory, dbName) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "unable to ensure database path") | ||
} | ||
|
||
return New(sqlite.Open(path), migrate) | ||
} | ||
|
||
// New creates a new instance of DB. | ||
func New(dial gorm.Dialector, migrate bool) (*DB, error) { | ||
// open db | ||
db, err := gorm.Open(dial, defaultGormConfig) | ||
if err != nil { | ||
return nil, errors.Wrap(err, "unable to open gorm database") | ||
} | ||
|
||
if migrate { | ||
if err := db.AutoMigrate(migrationEntities...); err != nil { | ||
return nil, errors.Wrap(err, "unable to migrate database") | ||
} | ||
} | ||
|
||
return &DB{db}, nil | ||
} | ||
|
||
// Client returns the underlying gorm database. | ||
func (db *DB) Client() *gorm.DB { | ||
return db.db | ||
} | ||
|
||
// Close closes the database. | ||
func (db *DB) Close() error { | ||
sqlDB, err := db.db.DB() | ||
if err != nil { | ||
return errors.Wrap(err, "unable to get underlying sql.DB") | ||
} | ||
|
||
if err := sqlDB.Close(); err != nil { | ||
return errors.Wrap(err, "unable to close sql.DB") | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func ensurePath(directory, dbName string) (string, error) { | ||
// pass in-memory database as is | ||
if strings.Contains(directory, SqliteInMemory) { | ||
return directory, nil | ||
} | ||
|
||
_, err := os.Stat(directory) | ||
switch { | ||
case os.IsNotExist(err): | ||
if err := os.MkdirAll(directory, dirCreationMode); err != nil { | ||
return "", errors.Wrapf(err, "unable to create database path %q", directory) | ||
} | ||
case err != nil: | ||
return "", errors.Wrap(err, "unable to check database path") | ||
} | ||
|
||
return fmt.Sprintf("%s/%s", directory, dbName), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package db | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"github.com/zeta-chain/zetacore/zetaclient/types" | ||
) | ||
|
||
func TestNew(t *testing.T) { | ||
t.Run("in memory", func(t *testing.T) { | ||
// ARRANGE | ||
// Given a database | ||
db, err := NewFromSqlite(SqliteInMemory, "", true) | ||
require.NoError(t, err) | ||
require.NotNil(t, db) | ||
|
||
// ACT | ||
runSampleSetGetTest(t, db) | ||
|
||
// Close the database | ||
assert.NoError(t, db.Close()) | ||
}) | ||
|
||
t.Run("file based", func(t *testing.T) { | ||
// ARRANGE | ||
// Given a tmp path | ||
directory, dbName := t.TempDir(), "test.db" | ||
|
||
// Given a database | ||
db, err := NewFromSqlite(directory, dbName, true) | ||
require.NoError(t, err) | ||
require.NotNil(t, db) | ||
|
||
// Check that the database file exists | ||
assert.FileExists(t, directory+"/"+dbName) | ||
|
||
// ACT | ||
runSampleSetGetTest(t, db) | ||
|
||
// Close the database | ||
assert.NoError(t, db.Close()) | ||
}) | ||
} | ||
|
||
func runSampleSetGetTest(t *testing.T, db *DB) { | ||
// Given a dummy sql type | ||
entity := types.ToLastBlockSQLType(444) | ||
|
||
// ACT #1 | ||
// Create entity | ||
result := db.Client().Create(&entity) | ||
|
||
// ASSERT | ||
assert.NoError(t, result.Error) | ||
|
||
// ACT #2 | ||
// Fetch entity | ||
var entity2 types.LastBlockSQLType | ||
|
||
result = db.Client().First(&entity2) | ||
|
||
// ASSERT | ||
assert.NoError(t, result.Error) | ||
assert.Equal(t, entity.Num, entity2.Num) | ||
} |