From ece0db539d71f1d9906edbb668f2fa258091668e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Men=C3=A9ndez?= Date: Thu, 11 Apr 2024 14:27:18 +0200 Subject: [PATCH] iterator interface --- scanner/providers/web3/web3_client.go | 86 +++++++++---------- .../web3/{web3_endpoint.go => web3_iter.go} | 53 ++++++------ scanner/providers/web3/web3_pool.go | 34 ++++---- 3 files changed, 86 insertions(+), 87 deletions(-) rename scanner/providers/web3/{web3_endpoint.go => web3_iter.go} (69%) diff --git a/scanner/providers/web3/web3_client.go b/scanner/providers/web3/web3_client.go index a5573634..ca0ca6f5 100644 --- a/scanner/providers/web3/web3_client.go +++ b/scanner/providers/web3/web3_client.go @@ -32,9 +32,9 @@ type Client struct { // EthClient method returns the ethclient.Client for the chainID of the Client // instance. It returns an error if the chainID is not found in the pool. func (c *Client) EthClient() (*ethclient.Client, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } return endpoint.client, nil } @@ -44,9 +44,9 @@ func (c *Client) EthClient() (*ethclient.Client, error) { // found in the pool or if the method fails. Required by the bind.ContractBackend // interface. func (c *Client) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -65,9 +65,9 @@ func (c *Client) CodeAt(ctx context.Context, account common.Address, blockNumber // not found in the pool or if the method fails. Required by the // bind.ContractBackend interface. func (c *Client) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -86,9 +86,9 @@ func (c *Client) CallContract(ctx context.Context, call ethereum.CallMsg, blockN // found in the pool or if the method fails. Required by the bind.ContractBackend // interface. func (c *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return 0, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return 0, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -107,9 +107,9 @@ func (c *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64, // found in the pool or if the method fails. Required by the bind.ContractBackend // interface. func (c *Client) FilterLogs(ctx context.Context, query ethereum.FilterQuery) ([]types.Log, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -128,9 +128,9 @@ func (c *Client) FilterLogs(ctx context.Context, query ethereum.FilterQuery) ([] // not found in the pool or if the method fails. Required by the // bind.ContractBackend interface. func (c *Client) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -149,9 +149,9 @@ func (c *Client) HeaderByNumber(ctx context.Context, number *big.Int) (*types.He // if the chainID is not found in the pool or if the method fails. Required by // the bind.ContractBackend interface. func (c *Client) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return 0, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return 0, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -170,9 +170,9 @@ func (c *Client) PendingNonceAt(ctx context.Context, account common.Address) (ui // if the chainID is not found in the pool or if the method fails. Required by // the bind.ContractBackend interface. func (c *Client) SuggestGasPrice(ctx context.Context) (*big.Int, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -191,12 +191,12 @@ func (c *Client) SuggestGasPrice(ctx context.Context) (*big.Int, error) { // not found in the pool or if the method fails. Required by the // bind.ContractBackend interface. func (c *Client) SendTransaction(ctx context.Context, tx *types.Transaction) error { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error - _, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { + _, err = c.retryAndCheckErr(endpoint.URI, func() (any, error) { internalCtx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() return nil, endpoint.client.SendTransaction(internalCtx, tx) @@ -209,9 +209,9 @@ func (c *Client) SendTransaction(ctx context.Context, tx *types.Transaction) err // not found in the pool or if the method fails. Required by the // bind.ContractBackend interface. func (c *Client) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -232,9 +232,9 @@ func (c *Client) PendingCodeAt(ctx context.Context, account common.Address) ([]b func (c *Client) SubscribeFilterLogs(ctx context.Context, query ethereum.FilterQuery, ch chan<- types.Log, ) (ethereum.Subscription, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -253,9 +253,9 @@ func (c *Client) SubscribeFilterLogs(ctx context.Context, // if the chainID is not found in the pool or if the method fails. Required by // the bind.ContractBackend interface. func (c *Client) SuggestGasTipCap(ctx context.Context) (*big.Int, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -274,9 +274,9 @@ func (c *Client) SuggestGasTipCap(ctx context.Context) (*big.Int, error) { // found in the pool or if the method fails. This method is required by internal // logic, it is not required by the bind.ContractBackend interface. func (c *Client) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { @@ -295,9 +295,9 @@ func (c *Client) BalanceAt(ctx context.Context, account common.Address, blockNum // found in the pool or if the method fails. This method is required by internal // logic, it is not required by the bind.ContractBackend interface. func (c *Client) BlockNumber(ctx context.Context) (uint64, error) { - endpoint, ok := c.w3p.EndpointByChainID(c.chainID) - if !ok { - return 0, fmt.Errorf("error getting endpoint for chainID %d", c.chainID) + endpoint, err := c.w3p.Endpoint(c.chainID) + if err != nil { + return 0, fmt.Errorf("error getting endpoint for chainID %d: %w", c.chainID, err) } // retry the method in case of failure and get final result and error res, err := c.retryAndCheckErr(endpoint.URI, func() (any, error) { diff --git a/scanner/providers/web3/web3_endpoint.go b/scanner/providers/web3/web3_iter.go similarity index 69% rename from scanner/providers/web3/web3_endpoint.go rename to scanner/providers/web3/web3_iter.go index 84accc21..286c9708 100644 --- a/scanner/providers/web3/web3_endpoint.go +++ b/scanner/providers/web3/web3_iter.go @@ -1,6 +1,7 @@ package web3 import ( + "fmt" "sync" "sync/atomic" @@ -18,22 +19,22 @@ type Web3Endpoint struct { client *ethclient.Client } -// Web3EndpointPool struct is a pool of Web3Endpoint that allows to get the next +// Web3Iterator struct is a pool of Web3Endpoint that allows to get the next // available endpoint in a round-robin fashion. It also allows to disable an // endpoint if it fails. It allows to manage multiple endpoints safely. -type Web3EndpointPool struct { +type Web3Iterator struct { nextIndex atomic.Uint32 available []*Web3Endpoint disabled []*Web3Endpoint mtx sync.Mutex } -// NewWeb3EndpointPool creates a new Web3EndpointPool with the given endpoints. -func newWeb3EndpointPool(endpoints ...*Web3Endpoint) *Web3EndpointPool { +// NewWeb3Iterator creates a new Web3Iterator with the given endpoints. +func NewWeb3Iterator(endpoints ...*Web3Endpoint) *Web3Iterator { if endpoints == nil { endpoints = make([]*Web3Endpoint, 0) } - return &Web3EndpointPool{ + return &Web3Iterator{ available: endpoints, disabled: make([]*Web3Endpoint, 0), } @@ -41,34 +42,26 @@ func newWeb3EndpointPool(endpoints ...*Web3Endpoint) *Web3EndpointPool { // Add adds a new endpoint to the pool, making it available for the next // requests. -func (w3pp *Web3EndpointPool) add(endpoint *Web3Endpoint) { +func (w3pp *Web3Iterator) Add(endpoint *Web3Endpoint) { w3pp.mtx.Lock() defer w3pp.mtx.Unlock() w3pp.available = append(w3pp.available, endpoint) } -// Next returns the next available endpoint in a round-robin fashion. If there -// are no endpoints, it will return nil. If there are no available endpoints, it -// will reset the disabled endpoints and return the first available endpoint. -func (w3pp *Web3EndpointPool) next() *Web3Endpoint { +// Next returns the next available endpoint in a round-robin fashion. If +// there are no endpoints, it will return an error. If there are no available +// endpoints, it will reset the disabled endpoints and return the first +// available endpoint. +func (w3pp *Web3Iterator) Next() (*Web3Endpoint, error) { w3pp.mtx.Lock() defer w3pp.mtx.Unlock() - // check if there is any available endpoint - l := len(w3pp.available) + l := uint32(len(w3pp.available)) if l == 0 { - // reset the next index and move the disabled endpoints to the available - w3pp.nextIndex.Store(0) - w3pp.available = append(w3pp.available, w3pp.disabled...) - w3pp.disabled = make([]*Web3Endpoint, 0) - // if continue to have no available endpoints, return nil - if len(w3pp.available) == 0 { - return nil - } - return w3pp.available[0] + return nil, fmt.Errorf("no available endpoints") } // get the current next index and endpoint currentIndex := w3pp.nextIndex.Load() - if int(currentIndex) >= l { + if currentIndex >= l { // if the current index is out of bounds, reset it to the first one currentIndex = 0 } @@ -80,17 +73,17 @@ func (w3pp *Web3EndpointPool) next() *Web3Endpoint { } // calculate the following next endpoint index based on the current one nextIndex := currentIndex + 1 - if int(nextIndex) >= l { + if nextIndex >= l { nextIndex = 0 } // update the next index and return the current endpoint w3pp.nextIndex.Store(nextIndex) - return currentEndpoint + return currentEndpoint, nil } -// disable method disables an endpoint, moving it from the available list to the +// Disable method disables an endpoint, moving it from the available list to the // the disabled list. -func (w3pp *Web3EndpointPool) disable(uri string) { +func (w3pp *Web3Iterator) Disable(uri string) { w3pp.mtx.Lock() defer w3pp.mtx.Unlock() // remove the endpoint from the available list @@ -100,4 +93,12 @@ func (w3pp *Web3EndpointPool) disable(uri string) { w3pp.disabled = append(w3pp.disabled, e) } } + // if there are no available endpoints, reset all the disabled ones to + // available ones and reset the next index to the first one + if l := len(w3pp.available); l == 0 { + // reset the next index and move the disabled endpoints to the available + w3pp.nextIndex.Store(0) + w3pp.available = append(w3pp.available, w3pp.disabled...) + w3pp.disabled = make([]*Web3Endpoint, 0) + } } diff --git a/scanner/providers/web3/web3_pool.go b/scanner/providers/web3/web3_pool.go index dcf098ac..afdc060a 100644 --- a/scanner/providers/web3/web3_pool.go +++ b/scanner/providers/web3/web3_pool.go @@ -29,7 +29,7 @@ import ( // It allows to support multiple endpoints for the same chainID and switch // between them looking for the available one. type Web3Pool struct { - endpoints map[uint64]*Web3EndpointPool + endpoints map[uint64]*Web3Iterator metadata []*Web3Endpoint } @@ -47,7 +47,7 @@ func NewWeb3Pool() (*Web3Pool, error) { return nil, fmt.Errorf("error decoding chains information from external source: %v", err) } return &Web3Pool{ - endpoints: make(map[uint64]*Web3EndpointPool), + endpoints: make(map[uint64]*Web3Iterator), metadata: chainsData, }, nil } @@ -91,9 +91,9 @@ func (nm *Web3Pool) AddEndpoint(uri string) error { client: client, } if _, ok := nm.endpoints[chainID]; !ok { - nm.endpoints[chainID] = newWeb3EndpointPool(endpoint) + nm.endpoints[chainID] = NewWeb3Iterator(endpoint) } else { - nm.endpoints[chainID].add(endpoint) + nm.endpoints[chainID].Add(endpoint) } return nil } @@ -103,32 +103,30 @@ func (nm *Web3Pool) AddEndpoint(uri string) error { // endpoints for the chainID where it was found. func (nm *Web3Pool) DelEndoint(uri string) { for _, endpoints := range nm.endpoints { - endpoints.disable(uri) + endpoints.Disable(uri) } } -// EndpointByChainID method returns the Web3Endpoint configured for the chainID +// Endpoint method returns the Web3Endpoint configured for the chainID // provided. It returns the first available endpoint. If no available endpoint -// is found, it resets the available flag for all, resets the next available to -// the first one and returns it. -func (nm *Web3Pool) EndpointByChainID(chainID uint64) (*Web3Endpoint, bool) { - next := nm.endpoints[chainID].next() - return next, next != nil +// is found, returns an error. +func (nm *Web3Pool) Endpoint(chainID uint64) (*Web3Endpoint, error) { + return nm.endpoints[chainID].Next() } // DisableEndpoint method sets the available flag to false for the URI provided // in the chainID provided. func (nm *Web3Pool) DisableEndpoint(chainID uint64, uri string) { if endpoints, ok := nm.endpoints[chainID]; ok { - endpoints.disable(uri) + endpoints.Disable(uri) } } -// GetClient method returns a new *Client instance for the chainID provided. +// Client method returns a new *Client instance for the chainID provided. // It returns an error if the endpoint is not found. func (nm *Web3Pool) Client(chainID uint64) (*Client, error) { - if _, ok := nm.EndpointByChainID(chainID); !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", chainID) + if _, err := nm.Endpoint(chainID); err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", chainID, err) } return &Client{w3p: nm, chainID: chainID}, nil } @@ -167,9 +165,9 @@ func (nm *Web3Pool) String() string { func (nm *Web3Pool) CurrentBlockNumbers(ctx context.Context) (map[uint64]uint64, error) { blockNumbers := make(map[uint64]uint64) for chainID := range nm.endpoints { - cli, ok := nm.EndpointByChainID(chainID) - if !ok { - return nil, fmt.Errorf("error getting endpoint for chainID %d", chainID) + cli, err := nm.Endpoint(chainID) + if err != nil { + return nil, fmt.Errorf("error getting endpoint for chainID %d: %w", chainID, err) } blockNumber, err := cli.client.BlockNumber(ctx) if err != nil {