diff --git a/protocol/chainlib/consumer_websocket_manager.go b/protocol/chainlib/consumer_websocket_manager.go index a3bd553424..e6edb7aaa6 100644 --- a/protocol/chainlib/consumer_websocket_manager.go +++ b/protocol/chainlib/consumer_websocket_manager.go @@ -3,9 +3,10 @@ package chainlib import ( "context" "strconv" + "sync/atomic" "time" - gojson "github.com/goccy/go-json" + "github.com/goccy/go-json" "github.com/gofiber/websocket/v2" formatter "github.com/lavanet/lava/v3/ecosystem/cache/format" "github.com/lavanet/lava/v3/protocol/common" @@ -13,8 +14,11 @@ import ( "github.com/lavanet/lava/v3/utils" "github.com/lavanet/lava/v3/utils/rand" spectypes "github.com/lavanet/lava/v3/x/spec/types" + "github.com/tidwall/gjson" ) +var WebSocketRateLimit = -1 // rate limit requests per second on websocket connection + type ConsumerWebsocketManager struct { websocketConn *websocket.Conn rpcConsumerLogs *metrics.RPCConsumerLogs @@ -67,6 +71,27 @@ func (cwm *ConsumerWebsocketManager) GetWebSocketConnectionUniqueId(dappId, user return dappId + "__" + userIp + "__" + cwm.WebsocketConnectionUID } +func (cwm *ConsumerWebsocketManager) handleRateLimitReached(inpData []byte) ([]byte, error) { + rateLimitError := common.JsonRpcRateLimitError + id := 0 + result := gjson.GetBytes(inpData, "id") + switch result.Type { + case gjson.Number: + id = int(result.Int()) + case gjson.String: + idParsed, err := strconv.Atoi(result.Raw) + if err == nil { + id = idParsed + } + } + rateLimitError.Id = id + bytesRateLimitError, err := json.Marshal(rateLimitError) + if err != nil { + return []byte{}, utils.LavaFormatError("failed marshalling jsonrpc rate limit error", err) + } + return bytesRateLimitError, nil +} + func (cwm *ConsumerWebsocketManager) ListenToMessages() { var ( messageType int @@ -110,6 +135,24 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { } }() + // rate limit routine + requestsPerSecond := &atomic.Uint64{} + go func() { + if WebSocketRateLimit <= 0 { + return + } + ticker := time.NewTicker(time.Second) // rate limit per second. + defer ticker.Stop() + for { + select { + case <-webSocketCtx.Done(): + return + case <-ticker.C: + requestsPerSecond.Store(0) + } + } + }() + for { startTime := time.Now() msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int @@ -125,6 +168,15 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { break } + // Check rate limit is met + if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) { + rateLimitResponse, err := cwm.handleRateLimitReached(msg) + if err == nil { + websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse} + } + continue + } + dappID, ok := websocketConn.Locals("dapp-id").(string) if !ok { // Log and remove the analyze @@ -167,7 +219,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { if err != nil { utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx)) if err == common.SubscriptionNotFoundError { - msgData, err := gojson.Marshal(common.JsonRpcSubscriptionNotFoundError) + msgData, err := json.Marshal(common.JsonRpcSubscriptionNotFoundError) if err != nil { continue } @@ -224,7 +276,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() { // Handle the case when the error is a method not found error if common.APINotSupportedError.Is(err) { - msgData, err := gojson.Marshal(common.JsonRpcMethodNotFoundError) + msgData, err := json.Marshal(common.JsonRpcMethodNotFoundError) if err != nil { continue } diff --git a/protocol/chainlib/consumer_ws_subscription_manager_test.go b/protocol/chainlib/consumer_ws_subscription_manager_test.go index 9aebc649a4..48573a3512 100644 --- a/protocol/chainlib/consumer_ws_subscription_manager_test.go +++ b/protocol/chainlib/consumer_ws_subscription_manager_test.go @@ -2,9 +2,11 @@ package chainlib import ( "context" + "fmt" "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -324,6 +326,11 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) { } } +func TestRateLimit(t *testing.T) { + numberOfRequests := &atomic.Uint64{} + fmt.Println(numberOfRequests.Load()) +} + func TestConsumerWSSubscriptionManager(t *testing.T) { // This test does the following: // 1. Create a new ConsumerWSSubscriptionManager diff --git a/protocol/common/cobra_common.go b/protocol/common/cobra_common.go index 40cbffdce2..fe75c8f31f 100644 --- a/protocol/common/cobra_common.go +++ b/protocol/common/cobra_common.go @@ -38,6 +38,9 @@ const ( SetProviderOptimizerBestTierPickChance = "set-provider-optimizer-best-tier-pick-chance" SetProviderOptimizerWorstTierPickChance = "set-provider-optimizer-worst-tier-pick-chance" SetProviderOptimizerNumberOfTiersToCreate = "set-provider-optimizer-number-of-tiers-to-create" + + // websocket flags + RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection" ) const ( diff --git a/protocol/common/return_errors.go b/protocol/common/return_errors.go index 5394ba1f3d..9020a26f17 100644 --- a/protocol/common/return_errors.go +++ b/protocol/common/return_errors.go @@ -27,6 +27,15 @@ var JsonRpcMethodNotFoundError = JsonRPCErrorMessage{ }, } +var JsonRpcRateLimitError = JsonRPCErrorMessage{ + JsonRPC: "2.0", + Id: 1, + Error: JsonRPCError{ + Code: 429, + Message: "Too Many Requests", + }, +} + var JsonRpcSubscriptionNotFoundError = JsonRPCErrorMessage{ JsonRPC: "2.0", Id: 1, diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 1223ef34be..77b07b4bf1 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -618,6 +618,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77 cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.ATierChance, common.SetProviderOptimizerBestTierPickChance, 0.75, "set the chances for picking a provider from the best group, default is 75% -> 0.75") cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.LastTierChance, common.SetProviderOptimizerWorstTierPickChance, 0.0, "set the chances for picking a provider from the worse group, default is 0% -> 0.0") cmdRPCConsumer.Flags().IntVar(&provideroptimizer.OptimizerNumTiers, common.SetProviderOptimizerNumberOfTiersToCreate, 4, "set the number of groups to create, default is 4") + cmdRPCConsumer.Flags().IntVar(&chainlib.WebSocketRateLimit, common.RateLimitWebSocketFlag, chainlib.WebSocketRateLimit, "rate limit (per second) websocket requests per user connection, default is unlimited") common.AddRollingLogConfig(cmdRPCConsumer) return cmdRPCConsumer }