Skip to content

Commit

Permalink
Merge branch 'main' into PRT-solana-tokens-owner-verification
Browse files Browse the repository at this point in the history
  • Loading branch information
shleikes authored Sep 26, 2024
2 parents a79afde + 5e3e277 commit c642cb2
Show file tree
Hide file tree
Showing 26 changed files with 310 additions and 123 deletions.
60 changes: 56 additions & 4 deletions protocol/chainlib/consumer_websocket_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ package chainlib
import (
"context"
"strconv"
"sync/atomic"
"time"

gojson "github.com/goccy/go-json"
"github.com/goccy/go-json"
"github.com/gofiber/websocket/v2"
formatter "github.com/lavanet/lava/v3/ecosystem/cache/format"
"github.com/lavanet/lava/v3/protocol/common"
"github.com/lavanet/lava/v3/protocol/metrics"
"github.com/lavanet/lava/v3/utils"
"github.com/lavanet/lava/v3/utils/rand"
spectypes "github.com/lavanet/lava/v3/x/spec/types"
"github.com/tidwall/gjson"
)

var WebSocketRateLimit = -1 // rate limit requests per second on websocket connection

type ConsumerWebsocketManager struct {
websocketConn *websocket.Conn
rpcConsumerLogs *metrics.RPCConsumerLogs
Expand Down Expand Up @@ -67,6 +71,27 @@ func (cwm *ConsumerWebsocketManager) GetWebSocketConnectionUniqueId(dappId, user
return dappId + "__" + userIp + "__" + cwm.WebsocketConnectionUID
}

func (cwm *ConsumerWebsocketManager) handleRateLimitReached(inpData []byte) ([]byte, error) {
rateLimitError := common.JsonRpcRateLimitError
id := 0
result := gjson.GetBytes(inpData, "id")
switch result.Type {
case gjson.Number:
id = int(result.Int())
case gjson.String:
idParsed, err := strconv.Atoi(result.Raw)
if err == nil {
id = idParsed
}
}
rateLimitError.Id = id
bytesRateLimitError, err := json.Marshal(rateLimitError)
if err != nil {
return []byte{}, utils.LavaFormatError("failed marshalling jsonrpc rate limit error", err)
}
return bytesRateLimitError, nil
}

func (cwm *ConsumerWebsocketManager) ListenToMessages() {
var (
messageType int
Expand Down Expand Up @@ -110,6 +135,24 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
}
}()

// rate limit routine
requestsPerSecond := &atomic.Uint64{}
go func() {
if WebSocketRateLimit <= 0 {
return
}
ticker := time.NewTicker(time.Second) // rate limit per second.
defer ticker.Stop()
for {
select {
case <-webSocketCtx.Done():
return
case <-ticker.C:
requestsPerSecond.Store(0)
}
}
}()

for {
startTime := time.Now()
msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int
Expand All @@ -125,6 +168,15 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
break
}

// Check rate limit is met
if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) {
rateLimitResponse, err := cwm.handleRateLimitReached(msg)
if err == nil {
websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse}
}
continue
}

dappID, ok := websocketConn.Locals("dapp-id").(string)
if !ok {
// Log and remove the analyze
Expand Down Expand Up @@ -160,14 +212,14 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
continue
}

// check whether its a normal relay / unsubscribe / unsubscribe_all otherwise its a subscription flow.
// check whether it's a normal relay / unsubscribe / unsubscribe_all otherwise its a subscription flow.
if !IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_SUBSCRIBE) {
if IsFunctionTagOfType(protocolMessage, spectypes.FUNCTION_TAG_UNSUBSCRIBE) {
err := cwm.consumerWsSubscriptionManager.Unsubscribe(webSocketCtx, protocolMessage, dappID, userIp, cwm.WebsocketConnectionUID, metricsData)
if err != nil {
utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx))
if err == common.SubscriptionNotFoundError {
msgData, err := gojson.Marshal(common.JsonRpcSubscriptionNotFoundError)
msgData, err := json.Marshal(common.JsonRpcSubscriptionNotFoundError)
if err != nil {
continue
}
Expand Down Expand Up @@ -224,7 +276,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {

// Handle the case when the error is a method not found error
if common.APINotSupportedError.Is(err) {
msgData, err := gojson.Marshal(common.JsonRpcMethodNotFoundError)
msgData, err := json.Marshal(common.JsonRpcMethodNotFoundError)
if err != nil {
continue
}
Expand Down
11 changes: 10 additions & 1 deletion protocol/chainlib/consumer_ws_subscription_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type ConsumerWSSubscriptionManager struct {
activeSubscriptionProvidersStorage *lavasession.ActiveSubscriptionProvidersStorage
currentlyPendingSubscriptions map[string]*pendingSubscriptionsBroadcastManager
lock sync.RWMutex
consumerMetricsManager *metrics.ConsumerMetricsManager
}

func NewConsumerWSSubscriptionManager(
Expand All @@ -65,6 +66,7 @@ func NewConsumerWSSubscriptionManager(
connectionType string,
chainParser ChainParser,
activeSubscriptionProvidersStorage *lavasession.ActiveSubscriptionProvidersStorage,
consumerMetricsManager *metrics.ConsumerMetricsManager,
) *ConsumerWSSubscriptionManager {
return &ConsumerWSSubscriptionManager{
connectedDapps: make(map[string]map[string]*common.SafeChannelSender[*pairingtypes.RelayReply]),
Expand All @@ -76,6 +78,7 @@ func NewConsumerWSSubscriptionManager(
relaySender: relaySender,
connectionType: connectionType,
activeSubscriptionProvidersStorage: activeSubscriptionProvidersStorage,
consumerMetricsManager: consumerMetricsManager,
}
}

Expand Down Expand Up @@ -216,6 +219,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription(

// called after send relay failure or parsing failure afterwards
onSubscriptionFailure := func() {
go cwsm.consumerMetricsManager.SetFailedWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType)
cwsm.failedPendingSubscription(hashedParams)
closeWebsocketRepliesChannel()
}
Expand Down Expand Up @@ -255,6 +259,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription(
// Validated there are no active subscriptions that we can use.
firstSubscriptionReply, returnWebsocketRepliesChan := cwsm.checkForActiveSubscriptionWithLock(webSocketCtx, hashedParams, protocolMessage, dappKey, websocketRepliesSafeChannelSender, closeWebsocketRepliesChannel)
if firstSubscriptionReply != nil {
go cwsm.consumerMetricsManager.SetDuplicatedWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType)
if returnWebsocketRepliesChan {
return firstSubscriptionReply, websocketRepliesChan, nil
}
Expand Down Expand Up @@ -412,7 +417,7 @@ func (cwsm *ConsumerWSSubscriptionManager) StartSubscription(
cwsm.successfulPendingSubscription(hashedParams)
// Need to be run once for subscription
go cwsm.listenForSubscriptionMessages(webSocketCtx, dappID, consumerIp, replyServer, hashedParams, providerAddr, metricsData, closeSubscriptionChan)

go cwsm.consumerMetricsManager.SetWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType)
return &reply, websocketRepliesChan, nil
}

Expand Down Expand Up @@ -524,19 +529,22 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages(
utils.LogAttr("GUID", webSocketCtx),
utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)),
)
go cwsm.consumerMetricsManager.SetWsSubscriptioDisconnectRequestMetric(metricsData.ChainID, metricsData.APIType, metrics.WsDisconnectionReasonUser)
return
case <-replyServer.Context().Done():
utils.LavaFormatTrace("reply server context canceled",
utils.LogAttr("GUID", webSocketCtx),
utils.LogAttr("hashedParams", utils.ToHexString(hashedParams)),
)
go cwsm.consumerMetricsManager.SetWsSubscriptioDisconnectRequestMetric(metricsData.ChainID, metricsData.APIType, metrics.WsDisconnectionReasonConsumer)
return
default:
var reply pairingtypes.RelayReply
err := replyServer.RecvMsg(&reply)
if err != nil {
// The connection was closed by the provider
utils.LavaFormatTrace("error reading from subscription stream", utils.LogAttr("original error", err.Error()))
go cwsm.consumerMetricsManager.SetWsSubscriptioDisconnectRequestMetric(metricsData.ChainID, metricsData.APIType, metrics.WsDisconnectionReasonProvider)
return
}
err = cwsm.handleIncomingSubscriptionNodeMessage(hashedParams, &reply, providerAddr)
Expand All @@ -545,6 +553,7 @@ func (cwsm *ConsumerWSSubscriptionManager) listenForSubscriptionMessages(
utils.LogAttr("hashedParams", hashedParams),
utils.LogAttr("reply", reply),
)
go cwsm.consumerMetricsManager.SetFailedWsSubscriptionRequestMetric(metricsData.ChainID, metricsData.APIType)
return
}
}
Expand Down
42 changes: 26 additions & 16 deletions protocol/chainlib/consumer_ws_subscription_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package chainlib

import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -27,6 +29,9 @@ import (
const (
numberOfParallelSubscriptions = 10
uniqueId = "1234"
projectHashTest = "test_projecthash"
chainIdTest = "test_chainId"
apiTypeTest = "test_apiType"
)

func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *testing.T) {
Expand All @@ -51,7 +56,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes
subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`),
},
}

metricsData := metrics.NewRelayAnalytics(projectHashTest, chainIdTest, apiTypeTest)
for _, play := range playbook {
t.Run(play.name, func(t *testing.T) {
ts := SetupForTests(t, 1, play.specId, "../../")
Expand Down Expand Up @@ -136,7 +141,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes
consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String())

// Create a new ConsumerWSSubscriptionManager
manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage())
manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage(), nil)
uniqueIdentifiers := make([]string, numberOfParallelSubscriptions)
wg := sync.WaitGroup{}
wg.Add(numberOfParallelSubscriptions)
Expand All @@ -151,7 +156,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes
var repliesChan <-chan *pairingtypes.RelayReply
var firstReply *pairingtypes.RelayReply

firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[index], nil)
firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[index], metricsData)
go func() {
for subMsg := range repliesChan {
// utils.LavaFormatInfo("got reply for index", utils.LogAttr("index", index))
Expand All @@ -169,15 +174,15 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptionsOnSameDappIdIp(t *tes
// now we have numberOfParallelSubscriptions subscriptions currently running
require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions)
// remove one
err = manager.Unsubscribe(ts.Ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[0], nil)
err = manager.Unsubscribe(ts.Ctx, protocolMessage1, dapp, ip, uniqueIdentifiers[0], metricsData)
require.NoError(t, err)
// now we have numberOfParallelSubscriptions - 1
require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-1)
// check we still have an active subscription.
require.Len(t, manager.activeSubscriptions, 1)

// same flow for unsubscribe all
err = manager.UnsubscribeAll(ts.Ctx, dapp, ip, uniqueIdentifiers[1], nil)
err = manager.UnsubscribeAll(ts.Ctx, dapp, ip, uniqueIdentifiers[1], metricsData)
require.NoError(t, err)
// now we have numberOfParallelSubscriptions - 2
require.Len(t, manager.connectedDapps, numberOfParallelSubscriptions-2)
Expand Down Expand Up @@ -209,7 +214,6 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) {
subscriptionFirstReply2: []byte(`{"jsonrpc":"2.0","id":4,"result":{}}`),
},
}

for _, play := range playbook {
t.Run(play.name, func(t *testing.T) {
ts := SetupForTests(t, 1, play.specId, "../../")
Expand Down Expand Up @@ -291,9 +295,9 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) {
Times(1) // Should call SendParsedRelay, because it is the first time we subscribe

consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String())

metricsData := metrics.NewRelayAnalytics(projectHashTest, chainIdTest, apiTypeTest)
// Create a new ConsumerWSSubscriptionManager
manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage())
manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage(), nil)

wg := sync.WaitGroup{}
wg.Add(10)
Expand All @@ -305,7 +309,7 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) {
ctx := utils.WithUniqueIdentifier(ts.Ctx, utils.GenerateUniqueIdentifier())
var repliesChan <-chan *pairingtypes.RelayReply
var firstReply *pairingtypes.RelayReply
firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, nil)
firstReply, repliesChan, err = manager.StartSubscription(ctx, protocolMessage1, dapp+strconv.Itoa(index), ts.Consumer.Addr.String(), uniqueId, metricsData)
go func() {
for subMsg := range repliesChan {
require.Equal(t, string(play.subscriptionFirstReply1), string(subMsg.Data))
Expand All @@ -322,6 +326,11 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) {
}
}

func TestRateLimit(t *testing.T) {
numberOfRequests := &atomic.Uint64{}
fmt.Println(numberOfRequests.Load())
}

func TestConsumerWSSubscriptionManager(t *testing.T) {
// This test does the following:
// 1. Create a new ConsumerWSSubscriptionManager
Expand Down Expand Up @@ -379,6 +388,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {
unsubscribeMessage2: []byte(`{"jsonrpc":"2.0","method":"eth_unsubscribe","params":["0x2134567890"],"id":1}`),
},
}
metricsData := metrics.NewRelayAnalytics(projectHashTest, chainIdTest, apiTypeTest)

for _, play := range playbook {
t.Run(play.name, func(t *testing.T) {
Expand Down Expand Up @@ -538,12 +548,12 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {
consumerSessionManager := CreateConsumerSessionManager(play.specId, play.apiInterface, ts.Consumer.Addr.String())

// Create a new ConsumerWSSubscriptionManager
manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage())
manager := NewConsumerWSSubscriptionManager(consumerSessionManager, relaySender, nil, play.connectionType, chainParser, lavasession.NewActiveSubscriptionProvidersStorage(), nil)

// Start a new subscription for the first time, called SendParsedRelay once
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())

firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil)
firstReply, repliesChan1, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData)
assert.NoError(t, err)
unsubscribeMessageWg.Add(1)
assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data))
Expand All @@ -559,7 +569,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {

// Start a subscription again, same params, same dappKey, should not call SendParsedRelay
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, nil)
firstReply, repliesChan2, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData)
assert.NoError(t, err)
assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data))
assert.Nil(t, repliesChan2) // Same subscription, same dappKey, no need for a new channel
Expand All @@ -568,7 +578,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {

// Start a subscription again, same params, different dappKey, should not call SendParsedRelay
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp2, ts.Consumer.Addr.String(), uniqueId, nil)
firstReply, repliesChan3, err := manager.StartSubscription(ctx, subscribeProtocolMessage1, dapp2, ts.Consumer.Addr.String(), uniqueId, metricsData)
assert.NoError(t, err)
assert.Equal(t, string(play.subscriptionFirstReply1), string(firstReply.Data))
assert.NotNil(t, repliesChan3) // Same subscription, but different dappKey, so will create new channel
Expand Down Expand Up @@ -652,7 +662,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {
// Start a subscription again, different params, same dappKey, should call SendParsedRelay
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())

firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeProtocolMessage2, dapp1, ts.Consumer.Addr.String(), uniqueId, nil)
firstReply, repliesChan4, err := manager.StartSubscription(ctx, subscribeProtocolMessage2, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData)
assert.NoError(t, err)
unsubscribeMessageWg.Add(1)
assert.Equal(t, string(play.subscriptionFirstReply2), string(firstReply.Data))
Expand All @@ -671,7 +681,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {

ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
unsubProtocolMessage := NewProtocolMessage(unsubscribeChainMessage1, nil, relayResult1.Request.RelayData, dapp2, ts.Consumer.Addr.String())
err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, nil)
err = manager.Unsubscribe(ctx, unsubProtocolMessage, dapp2, ts.Consumer.Addr.String(), uniqueId, metricsData)
require.NoError(t, err)

listenForExpectedMessages(ctx, repliesChan1, string(play.subscriptionFirstReply1))
Expand All @@ -697,7 +707,7 @@ func TestConsumerWSSubscriptionManager(t *testing.T) {
Times(2) // Should call SendParsedRelay, because it unsubscribed

ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
err = manager.UnsubscribeAll(ctx, dapp1, ts.Consumer.Addr.String(), uniqueId, nil)
err = manager.UnsubscribeAll(ctx, dapp1, ts.Consumer.Addr.String(), uniqueId, metricsData)
require.NoError(t, err)

expectNoMoreMessages(ctx, repliesChan1)
Expand Down
3 changes: 3 additions & 0 deletions protocol/common/cobra_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ const (
SetProviderOptimizerBestTierPickChance = "set-provider-optimizer-best-tier-pick-chance"
SetProviderOptimizerWorstTierPickChance = "set-provider-optimizer-worst-tier-pick-chance"
SetProviderOptimizerNumberOfTiersToCreate = "set-provider-optimizer-number-of-tiers-to-create"

// websocket flags
RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection"
)

const (
Expand Down
1 change: 1 addition & 0 deletions protocol/common/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
NODE_ERRORS_PROVIDERS_HEADER_NAME = "Lava-Node-Errors-providers"
REPORTED_PROVIDERS_HEADER_NAME = "Lava-Reported-Providers"
USER_REQUEST_TYPE = "lava-user-request-type"
STATEFUL_API_HEADER = "lava-stateful-api"
LAVA_IDENTIFIED_NODE_ERROR_HEADER = "lava-identified-node-error"
LAVAP_VERSION_HEADER_NAME = "Lavap-Version"
LAVA_CONSUMER_PROCESS_GUID = "lava-consumer-process-guid"
Expand Down
Loading

0 comments on commit c642cb2

Please sign in to comment.