Skip to content

Commit

Permalink
Session: use parent DB if available in Transaction()
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Jan 29, 2024
1 parent d3a5620 commit 51eb4ac
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 2 additions & 0 deletions util/errors/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func NewSkip(reason any, skip int) error {
}

// Errorf is a shortcut for `errors.New(fmt.Errorf("format", args))`.
// Be careful when using this, this will result in losing the callers of
// the original error if one of the `args` is of type `*errors.Error`.
func Errorf(format string, args ...any) error {
return NewSkip(fmt.Errorf(format, args...), 3)
}
Expand Down
2 changes: 1 addition & 1 deletion util/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type dbKey struct{}
// The Gorm DB associated with this session is injected into the context as a value so `session.DB()`
// can be used to retrieve it.
func (s Gorm) Transaction(ctx context.Context, f func(context.Context) error) error {
tx := s.db.WithContext(ctx).Begin(s.TxOptions)
tx := DB(ctx, s.db).WithContext(ctx).Begin(s.TxOptions)
if tx.Error != nil {
return errors.New(tx.Error)
}
Expand Down
27 changes: 27 additions & 0 deletions util/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils/tests"
"goyave.dev/goyave/v5/config"
"goyave.dev/goyave/v5/database"
Expand Down Expand Up @@ -122,12 +123,14 @@ func TestGormSession(t *testing.T) {

ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
tx, err := session.Begin(ctx)
tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB
assert.NoError(t, err)
assert.NotNil(t, tx)

subtx, err := session.Begin(tx.Context())
assert.NoError(t, err)
assert.Equal(t, "testvalue", subtx.(Gorm).db.Statement.Context.Value(testKey{})) // Parent context is kept
assert.Contains(t, subtx.(Gorm).db.Statement.Clauses, "testclause") // Parent DB is used
})

t.Run("Transaction", func(t *testing.T) {
Expand Down Expand Up @@ -155,6 +158,30 @@ func TestGormSession(t *testing.T) {
assert.False(t, committer.rolledback)
})

t.Run("Nested_Transaction", func(t *testing.T) {
db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
if !assert.NoError(t, err) {
return
}
committer := &testCommitter{}
db.Statement.ConnPool = committer
session := GORM(db, nil)

ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
tx, err := session.Begin(ctx)
tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB
assert.NoError(t, err)
assert.NotNil(t, tx)

err = session.Transaction(tx.Context(), func(ctx context.Context) error {
db := DB(ctx, nil)
assert.NotNil(t, db)
assert.Contains(t, db.Statement.Clauses, "testclause") // Parent DB is used
return nil
})
assert.NoError(t, err)
})

t.Run("TransactionError", func(t *testing.T) {
db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
if !assert.NoError(t, err) {
Expand Down

0 comments on commit 51eb4ac

Please sign in to comment.