Skip to content

Commit

Permalink
basic abstraction for endpoints pools by chain id
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasmenendez committed Apr 11, 2024
1 parent 5a35070 commit c0505c0
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 219 deletions.
2 changes: 1 addition & 1 deletion scanner/providers/farcaster/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (p *FarcasterProvider) Init(iconf any) error {
}
p.contracts.lastBlock.Store(uint64(lastBlock))
// init the web3 client and contracts
p.client, err = p.endpoints.GetClient(ChainID)
p.client, err = p.endpoints.Client(ChainID)
if err != nil {
return errors.Join(web3.ErrConnectingToWeb3Client, fmt.Errorf("[FARCASTER]: error getting web3 client: %w", err))
}
Expand Down
2 changes: 1 addition & 1 deletion scanner/providers/web3/erc20_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (p *ERC20HolderProvider) SetRef(iref any) error {
return fmt.Errorf("invalid ref type, it must be Web3ProviderRef")
}
var err error
p.client, err = p.endpoints.GetClient(ref.ChainID)
p.client, err = p.endpoints.Client(ref.ChainID)
if err != nil {
return fmt.Errorf("error getting web3 client for the given chainID: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion scanner/providers/web3/erc721_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (p *ERC721HolderProvider) SetRef(iref any) error {
return errors.New("invalid ref type, it must be Web3ProviderRef")
}
var err error
p.client, err = p.endpoints.GetClient(ref.ChainID)
p.client, err = p.endpoints.Client(ref.ChainID)
if err != nil {
return fmt.Errorf("error getting web3 client for the given chainID: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion scanner/providers/web3/erc777_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (p *ERC777HolderProvider) SetRef(iref any) error {
return errors.New("invalid ref type, it must be Web3ProviderRef")
}
var err error
p.client, err = p.endpoints.GetClient(ref.ChainID)
p.client, err = p.endpoints.Client(ref.ChainID)
if err != nil {
return fmt.Errorf("error getting web3 client for the given chainID: %w", err)
}
Expand Down
28 changes: 14 additions & 14 deletions scanner/providers/web3/web3_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Client struct {
// EthClient method returns the ethclient.Client for the chainID of the Client
// instance. It returns an error if the chainID is not found in the pool.
func (c *Client) EthClient() (*ethclient.Client, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -44,7 +44,7 @@ func (c *Client) EthClient() (*ethclient.Client, error) {
// found in the pool or if the method fails. Required by the bind.ContractBackend
// interface.
func (c *Client) CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -65,7 +65,7 @@ func (c *Client) CodeAt(ctx context.Context, account common.Address, blockNumber
// not found in the pool or if the method fails. Required by the
// bind.ContractBackend interface.
func (c *Client) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -86,7 +86,7 @@ func (c *Client) CallContract(ctx context.Context, call ethereum.CallMsg, blockN
// found in the pool or if the method fails. Required by the bind.ContractBackend
// interface.
func (c *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return 0, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -107,7 +107,7 @@ func (c *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64,
// found in the pool or if the method fails. Required by the bind.ContractBackend
// interface.
func (c *Client) FilterLogs(ctx context.Context, query ethereum.FilterQuery) ([]types.Log, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -128,7 +128,7 @@ func (c *Client) FilterLogs(ctx context.Context, query ethereum.FilterQuery) ([]
// not found in the pool or if the method fails. Required by the
// bind.ContractBackend interface.
func (c *Client) HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -149,7 +149,7 @@ func (c *Client) HeaderByNumber(ctx context.Context, number *big.Int) (*types.He
// if the chainID is not found in the pool or if the method fails. Required by
// the bind.ContractBackend interface.
func (c *Client) PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return 0, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -170,7 +170,7 @@ func (c *Client) PendingNonceAt(ctx context.Context, account common.Address) (ui
// if the chainID is not found in the pool or if the method fails. Required by
// the bind.ContractBackend interface.
func (c *Client) SuggestGasPrice(ctx context.Context) (*big.Int, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -191,7 +191,7 @@ func (c *Client) SuggestGasPrice(ctx context.Context) (*big.Int, error) {
// not found in the pool or if the method fails. Required by the
// bind.ContractBackend interface.
func (c *Client) SendTransaction(ctx context.Context, tx *types.Transaction) error {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -209,7 +209,7 @@ func (c *Client) SendTransaction(ctx context.Context, tx *types.Transaction) err
// not found in the pool or if the method fails. Required by the
// bind.ContractBackend interface.
func (c *Client) PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -232,7 +232,7 @@ func (c *Client) PendingCodeAt(ctx context.Context, account common.Address) ([]b
func (c *Client) SubscribeFilterLogs(ctx context.Context,
query ethereum.FilterQuery, ch chan<- types.Log,
) (ethereum.Subscription, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -253,7 +253,7 @@ func (c *Client) SubscribeFilterLogs(ctx context.Context,
// if the chainID is not found in the pool or if the method fails. Required by
// the bind.ContractBackend interface.
func (c *Client) SuggestGasTipCap(ctx context.Context) (*big.Int, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -274,7 +274,7 @@ func (c *Client) SuggestGasTipCap(ctx context.Context) (*big.Int, error) {
// found in the pool or if the method fails. This method is required by internal
// logic, it is not required by the bind.ContractBackend interface.
func (c *Client) BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return nil, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand All @@ -295,7 +295,7 @@ func (c *Client) BalanceAt(ctx context.Context, account common.Address, blockNum
// found in the pool or if the method fails. This method is required by internal
// logic, it is not required by the bind.ContractBackend interface.
func (c *Client) BlockNumber(ctx context.Context) (uint64, error) {
endpoint, ok := c.w3p.GetEndpoint(c.chainID)
endpoint, ok := c.w3p.EndpointByChainID(c.chainID)
if !ok {
return 0, fmt.Errorf("error getting endpoint for chainID %d", c.chainID)
}
Expand Down
103 changes: 103 additions & 0 deletions scanner/providers/web3/web3_endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package web3

import (
"sync"
"sync/atomic"

"github.com/ethereum/go-ethereum/ethclient"
)

// Web3Endpoint struct contains all the required information about a web3
// provider based on its URI. It includes its chain ID, its name (and shortName)
// and the URI.
type Web3Endpoint struct {
ChainID uint64 `json:"chainId"`
Name string `json:"name"`
ShortName string `json:"shortName"`
URI string
client *ethclient.Client
}

// Web3EndpointPool struct is a pool of Web3Endpoint that allows to get the next
// available endpoint in a round-robin fashion. It also allows to disable an
// endpoint if it fails. It allows to manage multiple endpoints safely.
type Web3EndpointPool struct {
nextIndex atomic.Uint32
available []*Web3Endpoint
disabled []*Web3Endpoint
mtx sync.Mutex
}

// NewWeb3EndpointPool creates a new Web3EndpointPool with the given endpoints.
func newWeb3EndpointPool(endpoints ...*Web3Endpoint) *Web3EndpointPool {
if endpoints == nil {
endpoints = make([]*Web3Endpoint, 0)
}
return &Web3EndpointPool{
available: endpoints,
disabled: make([]*Web3Endpoint, 0),
}
}

// Add adds a new endpoint to the pool, making it available for the next
// requests.
func (w3pp *Web3EndpointPool) add(endpoint *Web3Endpoint) {
w3pp.mtx.Lock()
defer w3pp.mtx.Unlock()
w3pp.available = append(w3pp.available, endpoint)
}

// Next returns the next available endpoint in a round-robin fashion. If there
// are no endpoints, it will return nil. If there are no available endpoints, it
// will reset the disabled endpoints and return the first available endpoint.
func (w3pp *Web3EndpointPool) next() *Web3Endpoint {
w3pp.mtx.Lock()
defer w3pp.mtx.Unlock()
// check if there is any available endpoint
l := len(w3pp.available)
if l == 0 {
// reset the next index and move the disabled endpoints to the available
w3pp.nextIndex.Store(0)
w3pp.available = append(w3pp.available, w3pp.disabled...)
w3pp.disabled = make([]*Web3Endpoint, 0)
// if continue to have no available endpoints, return nil
if len(w3pp.available) == 0 {
return nil
}
return w3pp.available[0]
}
// get the current next index and endpoint
currentIndex := w3pp.nextIndex.Load()
if int(currentIndex) >= l {
// if the current index is out of bounds, reset it to the first one
currentIndex = 0
}
currentEndpoint := w3pp.available[currentIndex]
if currentEndpoint == nil {
// if the current endpoint is nil, reset the index and get the first one
currentIndex = 0
currentEndpoint = w3pp.available[0]
}
// calculate the following next endpoint index based on the current one
nextIndex := currentIndex + 1
if int(nextIndex) >= l {
nextIndex = 0
}
// update the next index and return the current endpoint
w3pp.nextIndex.Store(nextIndex)
return currentEndpoint
}

// disable method disables an endpoint, moving it from the available list to the
// the disabled list.
func (w3pp *Web3EndpointPool) disable(uri string) {
w3pp.mtx.Lock()
defer w3pp.mtx.Unlock()
// remove the endpoint from the available list
for i, e := range w3pp.available {
if e.URI == uri {
w3pp.available = append(w3pp.available[:i], w3pp.available[i+1:]...)
w3pp.disabled = append(w3pp.disabled, e)
}
}
}
Loading

0 comments on commit c0505c0

Please sign in to comment.