Skip to content

Commit

Permalink
Session: fix nested transactions for Gorm's TxCommitter
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Jun 13, 2024
1 parent 616cad9 commit 44a81c1
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 36 deletions.
74 changes: 69 additions & 5 deletions util/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package session
import (
"context"
"database/sql"
"fmt"

"gorm.io/gorm"
"goyave.dev/goyave/v5/util/errors"
Expand Down Expand Up @@ -44,6 +45,7 @@ type Gorm struct {
db *gorm.DB
TxOptions *sql.TxOptions
ctx context.Context
savepoint string // Savepoint for manual nested transactions
}

// GORM create a new root session for Gorm.
Expand All @@ -60,10 +62,16 @@ func GORM(db *gorm.DB, opt *sql.TxOptions) Gorm {
// The returned session has manual controls. Make sure a call to `Rollback()` or `Commit()`
// is executed before the session is expired (eligible for garbage collection).
// The Gorm DB associated with this session is injected as a value into the new session's context.
// If a Gorm DB is found in the given context, it will be used instead of this Session's DB, allowing for
// nested transactions.
//
// If the newly created session is nested, a savepoint is generated instead. Calls to the returned
// session's `Rollback()` will rollback to this savepoint.
// This behavior is disabled if gorm config `DisableNestedTransaction` is set to `true`.
func (s Gorm) Begin(ctx context.Context) (Session, error) {
tx := DB(ctx, s.db).WithContext(ctx).Begin(s.TxOptions)
db := DB(ctx, s.db).WithContext(ctx)
if _, ok := db.Statement.ConnPool.(gorm.TxCommitter); ok {
return s.nestedBegin(db)
}
tx := db.Begin(s.TxOptions)
if tx.Error != nil {
return nil, errors.NewSkip(tx.Error, 3)
}
Expand All @@ -74,13 +82,43 @@ func (s Gorm) Begin(ctx context.Context) (Session, error) {
}, nil
}

func (s Gorm) nestedBegin(db *gorm.DB) (Session, error) {
nestedSession := Gorm{
ctx: context.WithValue(db.Statement.Context, dbKey{}, db),
TxOptions: s.TxOptions,
db: db,
}

nestedSession.savepoint = fmt.Sprintf("sp%p", nestedSession.ctx)
if !db.DisableNestedTransaction {
err := errors.NewSkip(db.SavePoint(nestedSession.savepoint).Error, 3)
if err != nil {
return nil, err
}
}
return nestedSession, nil
}

// Rollback the changes in the transaction. This action is final.
//
// If the session is nested, rolls back to the session's savepoint.
func (s Gorm) Rollback() error {
if s.savepoint != "" {
if s.db.DisableNestedTransaction {
return nil
}
return errors.NewSkip(s.db.RollbackTo(s.savepoint).Error, 3)
}
return errors.NewSkip(s.db.Rollback().Error, 3)
}

// Commit the changes in the transaction. This action is final.
//
// If the session is nested, calling Rollback() is a no-op.
func (s Gorm) Commit() error {
if s.savepoint != "" {
return nil
}
return errors.NewSkip(s.db.Commit().Error, 3)
}

Expand All @@ -100,12 +138,17 @@ type dbKey struct{}
// The Gorm DB associated with this session is injected into the context as a value so `session.DB()`
// can be used to retrieve it.
func (s Gorm) Transaction(ctx context.Context, f func(context.Context) error) error {
tx := DB(ctx, s.db).WithContext(ctx).Begin(s.TxOptions)
tx := DB(ctx, s.db).WithContext(ctx)
if _, ok := tx.Statement.ConnPool.(gorm.TxCommitter); ok {
return s.nestedTransaction(tx, f)
}

tx = tx.Begin(s.TxOptions)
if tx.Error != nil {
return errors.New(tx.Error)
}
c := context.WithValue(ctx, dbKey{}, tx)
err := f(c)
err := errors.New(f(c))
if err != nil {
tx.Rollback()
return errors.New(err)
Expand All @@ -117,6 +160,27 @@ func (s Gorm) Transaction(ctx context.Context, f func(context.Context) error) er
return nil
}

func (s Gorm) nestedTransaction(tx *gorm.DB, f func(context.Context) error) error {
panicked := true
savepoint := fmt.Sprintf("sp%p", f)
if !tx.DisableNestedTransaction {
err := tx.SavePoint(savepoint).Error
if err != nil {
return errors.New(err)
}
}
c := context.WithValue(tx.Statement.Context, dbKey{}, tx)
var err error
defer func() {
if !tx.DisableNestedTransaction && (panicked || err != nil) {
tx.RollbackTo(savepoint)
}
}()
err = errors.New(f(c))
panicked = false
return err
}

// DB returns the Gorm instance stored in the given context. Returns the given fallback
// if no Gorm DB could be found in the context.
func DB(ctx context.Context, fallback *gorm.DB) *gorm.DB {
Expand Down
Loading

0 comments on commit 44a81c1

Please sign in to comment.