Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TxWithRetries and test cases #109

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions internal/db/error.go
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all the methods here looks very similar, they all try to check the cmdErr code.
shall we have a private generic method that does that and compare the code instead? then this generic method can be called by IsWriteConflictError etc by passing in the 122 code into the method

Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package db

import (
"errors"
"log"

"go.mongodb.org/mongo-driver/mongo"
)

// DuplicateKeyError is an error type for duplicate key errors
type DuplicateKeyError struct {
Key string
Expand Down Expand Up @@ -43,3 +50,44 @@ func IsNotFoundError(err error) bool {
_, ok := err.(*NotFoundError)
return ok
}

// Error code references: https://www.mongodb.com/docs/manual/reference/error-codes/
func IsWriteConflictError(err error) bool {
if err == nil {
log.Println("Error is nil, cannot be a write conflict")
return false
}

var cmdErr *mongo.CommandError
if errors.As(err, &cmdErr) {
if cmdErr == nil {
log.Println("Error is not a CommandError, cannot be a write conflict")
return false
}
log.Println("Checking for write conflict error, code received:", cmdErr.Code)
return cmdErr.Code == 112
}

log.Println("Error does not conform to CommandError")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we using "github.com/rs/zerolog/log" in the service. i'm not sure if using log package directly would be a good idea.

Copy link
Collaborator

@jrwbabylonlab jrwbabylonlab May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just adding on top of this. i don't think we actually need any logs in this file. it's the calling methods responsibility to log errors.
That means u can simply remove all logs in this file

return false
}

func IsTransactionAbortedError(err error) bool {
if err == nil {
log.Println("Error is nil, cannot be a transaction aborted")
return false
}

var cmdErr *mongo.CommandError
if errors.As(err, &cmdErr) {
if cmdErr == nil {
log.Println("Error is not a CommandError, cannot be a transaction aborted")
return false
}
log.Println("Checking for transaction aborted error, code received:", cmdErr.Code)
return cmdErr.Code == 251
}

log.Println("Error does not conform to CommandError")
return false
}
9 changes: 9 additions & 0 deletions internal/db/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (

"github.com/babylonchain/staking-api-service/internal/db/model"
"github.com/babylonchain/staking-api-service/internal/types"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

type DBClient interface {
Expand Down Expand Up @@ -54,3 +56,10 @@ type DBClient interface {
) error
FindTopStakersByTvl(ctx context.Context, paginationToken string) (*DbResultMap[model.StakerStatsDocument], error)
}
type DBSession interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remind me why we need to define this here?
I'm thinking twice around this, feel like we shall not expose internal implementation of the db client here as the interface here are used for defining API contract, it's not mongodb specific.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was trying to make it an interface so that the function could consume both test db and mongo db

EndSession(ctx context.Context)
WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*options.TransactionOptions) (interface{}, error)
}
type DBTransactionClient interface {
StartSession(opts ...*options.SessionOptions) (DBSession, error)
}
133 changes: 61 additions & 72 deletions internal/db/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,25 @@ func (db *Database) GetOrCreateStatsLock(
func (db *Database) IncrementOverallStats(
ctx context.Context, stakingTxHashHex, stakerPkHex string, amount uint64,
) error {
overallStatsClient := db.Client.Database(db.DbName).Collection(model.OverallStatsCollection)
stakerStatsClient := db.Client.Database(db.DbName).Collection(model.StakerStatsCollection)
// Define the work to be done in the transaction
transactionWork := func(sessCtx mongo.SessionContext) (interface{}, error) {
overallStatsClient := db.Client.Database(db.DbName).Collection(model.OverallStatsCollection)
stakerStatsClient := db.Client.Database(db.DbName).Collection(model.StakerStatsCollection)

err := db.updateStatsLockByFieldName(sessCtx, stakingTxHashHex, types.Active.ToString(), "overall_stats")
if err != nil {
return nil, err
}

// Start a session
session, sessionErr := db.Client.StartSession()
if sessionErr != nil {
return sessionErr
// The order of the overall stats and staker stats update is important.
// The staker stats colleciton will need to be processed first to determine if the staker is new
// If the staker stats is the first delegation for the staker, we need to increment the total stakers
var stakerStats model.StakerStatsDocument
stakerStatsFilter := bson.M{"_id": stakerPkHex}
stakerErr := stakerStatsClient.FindOne(ctx, stakerStatsFilter).Decode(&stakerStats)
if stakerErr != nil {
return nil, stakerErr
}
defer session.EndSession(ctx)

upsertUpdate := bson.M{
"$inc": bson.M{
Expand All @@ -66,37 +76,26 @@ func (db *Database) IncrementOverallStats(
"total_delegations": 1,
},
}
// Define the work to be done in the transaction
transactionWork := func(sessCtx mongo.SessionContext) (interface{}, error) {
err := db.updateStatsLockByFieldName(sessCtx, stakingTxHashHex, types.Active.ToString(), "overall_stats")
if err != nil {
return nil, err
}

// The order of the overall stats and staker stats update is important.
// The staker stats colleciton will need to be processed first to determine if the staker is new
// If the staker stats is the first delegation for the staker, we need to increment the total stakers
var stakerStats model.StakerStatsDocument
stakerStatsFilter := bson.M{"_id": stakerPkHex}
stakerErr := stakerStatsClient.FindOne(ctx, stakerStatsFilter).Decode(&stakerStats)
if stakerErr != nil {
return nil, stakerErr
}
if stakerStats.TotalDelegations == 1 {
if stakerStats.TotalDelegations == 1 {
upsertUpdate["$inc"].(bson.M)["total_stakers"] = 1
}
}

upsertFilter := bson.M{"_id": db.generateOverallStatsId()}
upsertFilter := bson.M{"_id": db.generateOverallStatsId()}

_, err = overallStatsClient.UpdateOne(sessCtx, upsertFilter, upsertUpdate, options.Update().SetUpsert(true))
if err != nil {
return nil, err
}
return nil, nil
_, err = overallStatsClient.UpdateOne(sessCtx, upsertFilter, upsertUpdate, options.Update().SetUpsert(true))
if err != nil {
return nil, err
}
return nil, nil
}

// Execute the transaction
_, txErr := session.WithTransaction(ctx, transactionWork)
// Execute the transaction with retries
_, txErr := TxWithRetries(
ctx,
&dbTransactionClient{db.Client},
transactionWork,
)
if txErr != nil {
return txErr
}
Expand All @@ -110,23 +109,15 @@ func (db *Database) IncrementOverallStats(
func (db *Database) SubtractOverallStats(
ctx context.Context, stakingTxHashHex, stakerPkHex string, amount uint64,
) error {
upsertUpdate := bson.M{
"$inc": bson.M{
"active_tvl": -int64(amount),
"active_delegations": -1,
},
}
overallStatsClient := db.Client.Database(db.DbName).Collection(model.OverallStatsCollection)

// Start a session
session, sessionErr := db.Client.StartSession()
if sessionErr != nil {
return sessionErr
}
defer session.EndSession(ctx)

// Define the work to be done in the transaction
transactionWork := func(sessCtx mongo.SessionContext) (interface{}, error) {
upsertUpdate := bson.M{
"$inc": bson.M{
"active_tvl": -int64(amount),
"active_delegations": -1,
},
}
overallStatsClient := db.Client.Database(db.DbName).Collection(model.OverallStatsCollection)
err := db.updateStatsLockByFieldName(sessCtx, stakingTxHashHex, types.Unbonded.ToString(), "overall_stats")
if err != nil {
return nil, err
Expand All @@ -141,8 +132,12 @@ func (db *Database) SubtractOverallStats(
return nil, nil
}

// Execute the transaction
_, txErr := session.WithTransaction(ctx, transactionWork)
// Execute the transaction with retries
_, txErr := TxWithRetries(
ctx,
&dbTransactionClient{db.Client},
transactionWork,
)
if txErr != nil {
return txErr
}
Expand Down Expand Up @@ -284,16 +279,9 @@ func (db *Database) FindFinalityProviderStats(ctx context.Context, paginationTok
}

func (db *Database) updateFinalityProviderStats(ctx context.Context, state, stakingTxHashHex, fpPkHex string, upsertUpdate primitive.M) error {
client := db.Client.Database(db.DbName).Collection(model.FinalityProviderStatsCollection)

// Start a session
session, sessionErr := db.Client.StartSession()
if sessionErr != nil {
return sessionErr
}
defer session.EndSession(ctx)

transactionWork := func(sessCtx mongo.SessionContext) (interface{}, error) {
client := db.Client.Database(db.DbName).Collection(model.FinalityProviderStatsCollection)

err := db.updateStatsLockByFieldName(sessCtx, stakingTxHashHex, state, "finality_provider_stats")
if err != nil {
return nil, err
Expand All @@ -308,8 +296,12 @@ func (db *Database) updateFinalityProviderStats(ctx context.Context, state, stak
return nil, nil
}

// Execute the transaction
_, txErr := session.WithTransaction(ctx, transactionWork)
// Execute the transaction with retries
_, txErr := TxWithRetries(
ctx,
&dbTransactionClient{db.Client},
transactionWork,
)
if txErr != nil {
return txErr
}
Expand Down Expand Up @@ -347,17 +339,10 @@ func (db *Database) SubtractStakerStats(
return db.updateStakerStats(ctx, types.Unbonded.ToString(), stakingTxHashHex, stakerPkHex, upsertUpdate)
}

func (db *Database) updateStakerStats(ctx context.Context, state, stakingTxHashHex, stakerPkHex string, upsertUpdate primitive.M) error {
client := db.Client.Database(db.DbName).Collection(model.StakerStatsCollection)

// Start a session
session, sessionErr := db.Client.StartSession()
if sessionErr != nil {
return sessionErr
}
defer session.EndSession(ctx)

func (db *Database) updateStakerStats(ctx context.Context, state, stakingTxHashHex, stakerPkHex string, upsertUpdate primitive.M) error {
transactionWork := func(sessCtx mongo.SessionContext) (interface{}, error) {
client := db.Client.Database(db.DbName).Collection(model.StakerStatsCollection)

err := db.updateStatsLockByFieldName(sessCtx, stakingTxHashHex, state, "staker_stats")
if err != nil {
return nil, err
Expand All @@ -372,8 +357,12 @@ func (db *Database) updateStakerStats(ctx context.Context, state, stakingTxHashH
return nil, nil
}

// Execute the transaction
_, txErr := session.WithTransaction(ctx, transactionWork)
// Execute the transaction with retries
_, txErr := TxWithRetries(
ctx,
&dbTransactionClient{db.Client},
transactionWork,
)
return txErr
}

Expand Down
104 changes: 104 additions & 0 deletions internal/db/transactions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package db

import (
"context"
"log"
"time"

utils "github.com/babylonchain/staking-api-service/internal/utils"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

const (
DefaultMaxAttempts = 4 // max attempt INCLUDES the first execution
DefaultInitialBackoff = 100 * time.Millisecond
DefaultBackoffFactor = 2.0
)

type dbTransactionClient struct {
*mongo.Client
}

type dbSessionWrapper struct {
mongo.Session
}

func (c *dbTransactionClient) StartSession(opts...*options.SessionOptions) (DBSession, error) {
session, err := c.Client.StartSession(opts...)
if err!= nil {
return nil, err
}
return &dbSessionWrapper{session}, nil
}


func (s *dbSessionWrapper) EndSession(ctx context.Context) {
s.Session.EndSession(ctx)
}

func (s *dbSessionWrapper) WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts...*options.TransactionOptions) (interface{}, error) {
return s.Session.WithTransaction(ctx, fn, opts...)
}

func TxWithRetries(
ctx context.Context,
dbTransactionClient DBTransactionClient,
txnFunc func(sessCtx mongo.SessionContext) (interface{}, error),
) (interface{}, error) {
maxAttempts := DefaultMaxAttempts
Copy link
Collaborator

@jrwbabylonlab jrwbabylonlab May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maxAttempts is not overwritten anywhere in this method, why not use the pre-defined variable directly in this case?
same as initialBackoff and backoffFactor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I can remove that

initialBackoff := DefaultInitialBackoff
backoffFactor := DefaultBackoffFactor

var (
result interface{}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there might be too many tabs here

err error
backoff = initialBackoff
)

for attempt := 1; attempt <= maxAttempts; attempt++ {
session, sessionErr := dbTransactionClient.StartSession();


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra space

if sessionErr != nil {
return nil, sessionErr
}


result, err = session.WithTransaction(ctx, txnFunc)
session.EndSession(ctx)

if err != nil {
if shouldRetry(err) && attempt < maxAttempts {
log.Printf("Attempt %d failed with retryable error: %v. Retrying after %v...", attempt, err, backoff)
utils.Sleep(backoff)
backoff *= time.Duration(backoffFactor)
continue
}
log.Printf("Attempt %d failed with non-retryable error: %v", attempt, err)
return nil, err
}
break
}
return result, nil
}

// Check for network-related, timeout errors, write conflicts or transaction aborted, which are generally transient should retry. Other errors such as duplicated keys or other non-specified errors should be considered non-retryable.
func shouldRetry(err error) bool {
if mongo.IsNetworkError(err) {
return true
}
if mongo.IsTimeout(err) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the network error and timeout error from mongo is already retry handled by the mongo driver implementation of WithTransaction

return true
}

if IsWriteConflictError(err) {
return true
}

if IsTransactionAbortedError(err) {
return true
}

return false
}
Loading