Skip to content

Commit

Permalink
feat: PRT - add rate limit to ws (#1713)
Browse files Browse the repository at this point in the history
* feat: PRT - add rate limit to ws

* lintush

* Update protocol/chainlib/consumer_websocket_manager.go
  • Loading branch information
ranlavanet authored Sep 25, 2024
1 parent bdb5f04 commit 7d0aefa
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 3 deletions.
58 changes: 55 additions & 3 deletions protocol/chainlib/consumer_websocket_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ package chainlib
import (
"context"
"strconv"
"sync/atomic"
"time"

gojson "github.com/goccy/go-json"
"github.com/goccy/go-json"
"github.com/gofiber/websocket/v2"
formatter "github.com/lavanet/lava/v3/ecosystem/cache/format"
"github.com/lavanet/lava/v3/protocol/common"
"github.com/lavanet/lava/v3/protocol/metrics"
"github.com/lavanet/lava/v3/utils"
"github.com/lavanet/lava/v3/utils/rand"
spectypes "github.com/lavanet/lava/v3/x/spec/types"
"github.com/tidwall/gjson"
)

var WebSocketRateLimit = -1 // rate limit requests per second on websocket connection

type ConsumerWebsocketManager struct {
websocketConn *websocket.Conn
rpcConsumerLogs *metrics.RPCConsumerLogs
Expand Down Expand Up @@ -67,6 +71,27 @@ func (cwm *ConsumerWebsocketManager) GetWebSocketConnectionUniqueId(dappId, user
return dappId + "__" + userIp + "__" + cwm.WebsocketConnectionUID
}

func (cwm *ConsumerWebsocketManager) handleRateLimitReached(inpData []byte) ([]byte, error) {
rateLimitError := common.JsonRpcRateLimitError
id := 0
result := gjson.GetBytes(inpData, "id")
switch result.Type {
case gjson.Number:
id = int(result.Int())
case gjson.String:
idParsed, err := strconv.Atoi(result.Raw)
if err == nil {
id = idParsed
}
}
rateLimitError.Id = id
bytesRateLimitError, err := json.Marshal(rateLimitError)
if err != nil {
return []byte{}, utils.LavaFormatError("failed marshalling jsonrpc rate limit error", err)
}
return bytesRateLimitError, nil
}

func (cwm *ConsumerWebsocketManager) ListenToMessages() {
var (
messageType int
Expand Down Expand Up @@ -110,6 +135,24 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
}
}()

// rate limit routine
requestsPerSecond := &atomic.Uint64{}
go func() {
if WebSocketRateLimit <= 0 {
return
}
ticker := time.NewTicker(time.Second) // rate limit per second.
defer ticker.Stop()
for {
select {
case <-webSocketCtx.Done():
return
case <-ticker.C:
requestsPerSecond.Store(0)
}
}
}()

for {
startTime := time.Now()
msgSeed := guidString + "_" + strconv.Itoa(rand.Intn(10000000000)) // use message seed with original guid and new int
Expand All @@ -125,6 +168,15 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
break
}

// Check rate limit is met
if WebSocketRateLimit > 0 && requestsPerSecond.Add(1) > uint64(WebSocketRateLimit) {
rateLimitResponse, err := cwm.handleRateLimitReached(msg)
if err == nil {
websocketConnWriteChan <- webSocketMsgWithType{messageType: messageType, msg: rateLimitResponse}
}
continue
}

dappID, ok := websocketConn.Locals("dapp-id").(string)
if !ok {
// Log and remove the analyze
Expand Down Expand Up @@ -167,7 +219,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {
if err != nil {
utils.LavaFormatWarning("error unsubscribing from subscription", err, utils.LogAttr("GUID", webSocketCtx))
if err == common.SubscriptionNotFoundError {
msgData, err := gojson.Marshal(common.JsonRpcSubscriptionNotFoundError)
msgData, err := json.Marshal(common.JsonRpcSubscriptionNotFoundError)
if err != nil {
continue
}
Expand Down Expand Up @@ -224,7 +276,7 @@ func (cwm *ConsumerWebsocketManager) ListenToMessages() {

// Handle the case when the error is a method not found error
if common.APINotSupportedError.Is(err) {
msgData, err := gojson.Marshal(common.JsonRpcMethodNotFoundError)
msgData, err := json.Marshal(common.JsonRpcMethodNotFoundError)
if err != nil {
continue
}
Expand Down
7 changes: 7 additions & 0 deletions protocol/chainlib/consumer_ws_subscription_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package chainlib

import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -324,6 +326,11 @@ func TestConsumerWSSubscriptionManagerParallelSubscriptions(t *testing.T) {
}
}

func TestRateLimit(t *testing.T) {
numberOfRequests := &atomic.Uint64{}
fmt.Println(numberOfRequests.Load())
}

func TestConsumerWSSubscriptionManager(t *testing.T) {
// This test does the following:
// 1. Create a new ConsumerWSSubscriptionManager
Expand Down
3 changes: 3 additions & 0 deletions protocol/common/cobra_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ const (
SetProviderOptimizerBestTierPickChance = "set-provider-optimizer-best-tier-pick-chance"
SetProviderOptimizerWorstTierPickChance = "set-provider-optimizer-worst-tier-pick-chance"
SetProviderOptimizerNumberOfTiersToCreate = "set-provider-optimizer-number-of-tiers-to-create"

// websocket flags
RateLimitWebSocketFlag = "rate-limit-websocket-requests-per-connection"
)

const (
Expand Down
9 changes: 9 additions & 0 deletions protocol/common/return_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ var JsonRpcMethodNotFoundError = JsonRPCErrorMessage{
},
}

var JsonRpcRateLimitError = JsonRPCErrorMessage{
JsonRPC: "2.0",
Id: 1,
Error: JsonRPCError{
Code: 429,
Message: "Too Many Requests",
},
}

var JsonRpcSubscriptionNotFoundError = JsonRPCErrorMessage{
JsonRPC: "2.0",
Id: 1,
Expand Down
1 change: 1 addition & 0 deletions protocol/rpcconsumer/rpcconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.ATierChance, common.SetProviderOptimizerBestTierPickChance, 0.75, "set the chances for picking a provider from the best group, default is 75% -> 0.75")
cmdRPCConsumer.Flags().Float64Var(&provideroptimizer.LastTierChance, common.SetProviderOptimizerWorstTierPickChance, 0.0, "set the chances for picking a provider from the worse group, default is 0% -> 0.0")
cmdRPCConsumer.Flags().IntVar(&provideroptimizer.OptimizerNumTiers, common.SetProviderOptimizerNumberOfTiersToCreate, 4, "set the number of groups to create, default is 4")
cmdRPCConsumer.Flags().IntVar(&chainlib.WebSocketRateLimit, common.RateLimitWebSocketFlag, chainlib.WebSocketRateLimit, "rate limit (per second) websocket requests per user connection, default is unlimited")
common.AddRollingLogConfig(cmdRPCConsumer)
return cmdRPCConsumer
}
Expand Down

0 comments on commit 7d0aefa

Please sign in to comment.