Skip to content

Commit

Permalink
Add db package
Browse files Browse the repository at this point in the history
  • Loading branch information
swift1337 committed Jul 18, 2024
1 parent 232b48e commit 0a41312
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 0 deletions.
104 changes: 104 additions & 0 deletions zetaclient/db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Package db represents API for database operations.
package db

import (
"fmt"
"os"
"strings"

"github.com/pkg/errors"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"

"github.com/zeta-chain/zetacore/zetaclient/types"
)

// SqliteInMemory is a special string to use in-memory database.
// @see https://www.sqlite.org/inmemorydb.html
const SqliteInMemory = ":memory:"

// read/write/execute for user
// read/write for group
const dirCreationMode = 0o750

var (
defaultGormConfig = &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
}
migrationEntities = []any{
&types.LastBlockSQLType{},
&types.TransactionSQLType{},
&types.ReceiptSQLType{},
}
)

// DB database.
type DB struct {
db *gorm.DB
}

// NewFromSqlite creates a new instance of DB based on SQLite database.
func NewFromSqlite(directory, dbName string, migrate bool) (*DB, error) {
path, err := ensurePath(directory, dbName)
if err != nil {
return nil, errors.Wrap(err, "unable to ensure database path")
}

return New(sqlite.Open(path), migrate)
}

// New creates a new instance of DB.
func New(dial gorm.Dialector, migrate bool) (*DB, error) {
// open db
db, err := gorm.Open(dial, defaultGormConfig)
if err != nil {
return nil, errors.Wrap(err, "unable to open gorm database")
}

if migrate {
if err := db.AutoMigrate(migrationEntities...); err != nil {
return nil, errors.Wrap(err, "unable to migrate database")
}
}

return &DB{db}, nil
}

// Client returns the underlying gorm database.
func (db *DB) Client() *gorm.DB {
return db.db
}

// Close closes the database.
func (db *DB) Close() error {
sqlDB, err := db.db.DB()
if err != nil {
return errors.Wrap(err, "unable to get underlying sql.DB")
}

if err := sqlDB.Close(); err != nil {
return errors.Wrap(err, "unable to close sql.DB")
}

return nil
}

func ensurePath(directory, dbName string) (string, error) {
// pass in-memory database as is
if strings.Contains(directory, SqliteInMemory) {
return directory, nil
}

_, err := os.Stat(directory)
switch {
case os.IsNotExist(err):
if err := os.MkdirAll(directory, dirCreationMode); err != nil {
return "", errors.Wrapf(err, "unable to create database path %q", directory)
}
case err != nil:
return "", errors.Wrap(err, "unable to check database path")
}

return fmt.Sprintf("%s/%s", directory, dbName), nil
}
67 changes: 67 additions & 0 deletions zetaclient/db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package db

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeta-chain/zetacore/zetaclient/types"
)

func TestNew(t *testing.T) {
t.Run("in memory", func(t *testing.T) {
// ARRANGE
// Given a database
db, err := NewFromSqlite(SqliteInMemory, "", true)
require.NoError(t, err)
require.NotNil(t, db)

// ACT
runSampleSetGetTest(t, db)

// Close the database
assert.NoError(t, db.Close())
})

t.Run("file based", func(t *testing.T) {
// ARRANGE
// Given a tmp path
directory, dbName := t.TempDir(), "test.db"

// Given a database
db, err := NewFromSqlite(directory, dbName, true)
require.NoError(t, err)
require.NotNil(t, db)

// Check that the database file exists
assert.FileExists(t, directory+"/"+dbName)

// ACT
runSampleSetGetTest(t, db)

// Close the database
assert.NoError(t, db.Close())
})
}

func runSampleSetGetTest(t *testing.T, db *DB) {
// Given a dummy sql type
entity := types.ToLastBlockSQLType(444)

// ACT #1
// Create entity
result := db.Client().Create(&entity)

// ASSERT
assert.NoError(t, result.Error)

// ACT #2
// Fetch entity
var entity2 types.LastBlockSQLType

result = db.Client().First(&entity2)

// ASSERT
assert.NoError(t, result.Error)
assert.Equal(t, entity.Num, entity2.Num)
}

0 comments on commit 0a41312

Please sign in to comment.