diff --git a/pkg/source/makerpsm/pool_tracker.go b/pkg/source/makerpsm/pool_tracker.go index 17de1f529..a0a34e2fa 100644 --- a/pkg/source/makerpsm/pool_tracker.go +++ b/pkg/source/makerpsm/pool_tracker.go @@ -11,6 +11,8 @@ import ( "github.com/KyberNetwork/kyberswap-dex-lib/pkg/entity" sourcePool "github.com/KyberNetwork/kyberswap-dex-lib/pkg/source/pool" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethclient/gethclient" ) type PoolTracker struct { @@ -27,10 +29,27 @@ func NewPoolTracker(cfg *Config, ethrpcClient *ethrpc.Client) *PoolTracker { } } +func (d *PoolTracker) GetNewPoolStateWithOverrides( + ctx context.Context, + p entity.Pool, + params sourcePool.GetNewPoolStateWithOverridesParams, +) (entity.Pool, error) { + return d.getNewPoolState(ctx, p, sourcePool.GetNewPoolStateParams{Logs: params.Logs}, params.Overrides) +} + func (d *PoolTracker) GetNewPoolState( + ctx context.Context, + p entity.Pool, + params sourcePool.GetNewPoolStateParams, +) (entity.Pool, error) { + return d.getNewPoolState(ctx, p, params, nil) +} + +func (d *PoolTracker) getNewPoolState( ctx context.Context, pool entity.Pool, _ sourcePool.GetNewPoolStateParams, + overrides map[common.Address]gethclient.OverrideAccount, ) (entity.Pool, error) { defer func(startTime time.Time) { logger. @@ -42,7 +61,7 @@ func (d *PoolTracker) GetNewPoolState( Info("finished GetNewPoolState") }(time.Now()) - psm, err := d.getPSM(ctx, pool.Address) + psm, err := d.getPSM(ctx, pool.Address, overrides) if err != nil { logger.WithFields(logger.Fields{ "dexID": d.cfg.DexID, @@ -81,8 +100,12 @@ func (d *PoolTracker) GetNewPoolState( return pool, nil } -func (d *PoolTracker) getPSM(ctx context.Context, address string) (*PSM, error) { - psm, err := d.psmReader.Read(ctx, address) +func (d *PoolTracker) getPSM( + ctx context.Context, + address string, + overrides map[common.Address]gethclient.OverrideAccount, +) (*PSM, error) { + psm, err := d.psmReader.Read(ctx, address, overrides) if err != nil { logger.WithFields(logger.Fields{ "dexID": d.cfg.DexID, @@ -91,7 +114,7 @@ func (d *PoolTracker) getPSM(ctx context.Context, address string) (*PSM, error) return nil, err } - vat, err := d.vatReader.Read(ctx, psm.VatAddress.String(), psm.ILK) + vat, err := d.vatReader.Read(ctx, psm.VatAddress.String(), psm.ILK, overrides) if err != nil { logger.WithFields(logger.Fields{ "dexID": d.cfg.DexID, diff --git a/pkg/source/makerpsm/psm_reader.go b/pkg/source/makerpsm/psm_reader.go index 40874250f..5af896121 100644 --- a/pkg/source/makerpsm/psm_reader.go +++ b/pkg/source/makerpsm/psm_reader.go @@ -6,6 +6,8 @@ import ( "github.com/KyberNetwork/ethrpc" "github.com/KyberNetwork/logger" "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethclient/gethclient" ) type PSMReader struct { @@ -20,7 +22,7 @@ func NewPSMReader(ethrpcClient *ethrpc.Client) *PSMReader { } } -func (r *PSMReader) Read(ctx context.Context, address string) (*PSM, error) { +func (r *PSMReader) Read(ctx context.Context, address string, overrides map[common.Address]gethclient.OverrideAccount) (*PSM, error) { var psm PSM req := r.ethrpcClient. @@ -51,6 +53,9 @@ func (r *PSMReader) Read(ctx context.Context, address string) (*PSM, error) { Params: nil, }, []interface{}{&psm.ILK}) + if overrides != nil { + req.SetOverrides(overrides) + } _, err := req.Aggregate() if err != nil { logger.WithFields(logger.Fields{ diff --git a/pkg/source/makerpsm/vat_reader.go b/pkg/source/makerpsm/vat_reader.go index daed7fd98..b1c228ead 100644 --- a/pkg/source/makerpsm/vat_reader.go +++ b/pkg/source/makerpsm/vat_reader.go @@ -6,6 +6,8 @@ import ( "github.com/KyberNetwork/ethrpc" "github.com/KyberNetwork/logger" "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethclient/gethclient" ) type VatReader struct { @@ -20,7 +22,7 @@ func NewVatReader(ethrpcClient *ethrpc.Client) *VatReader { } } -func (r *VatReader) Read(ctx context.Context, address string, ilk [32]byte) (*Vat, error) { +func (r *VatReader) Read(ctx context.Context, address string, ilk [32]byte, overrides map[common.Address]gethclient.OverrideAccount) (*Vat, error) { var vat Vat req := r.ethrpcClient. @@ -45,6 +47,9 @@ func (r *VatReader) Read(ctx context.Context, address string, ilk [32]byte) (*Va Params: []interface{}{ilk}, }, []interface{}{&vat.ILK}) + if overrides != nil { + req.SetOverrides(overrides) + } _, err := req.Aggregate() if err != nil { logger.WithFields(logger.Fields{