diff --git a/gossipsub.go b/gossipsub.go index 849f2fa1..b67b47c6 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -68,7 +68,8 @@ var ( GossipSubGraftFloodThreshold = 10 * time.Second GossipSubMaxIHaveLength = 5000 GossipSubMaxIHaveMessages = 10 - GossipSubMaxIDontWantMessages = 1000 + GossipSubMaxIDontWantLength = 100 + GossipSubMaxIDontWantMessages = 100 GossipSubIWantFollowupTime = 3 * time.Second GossipSubIDontWantMessageThreshold = 1024 // 1KB GossipSubIDontWantMessageTTL = 3 // 3 heartbeats @@ -218,6 +219,10 @@ type GossipSubParams struct { // MaxIHaveMessages is the maximum number of IHAVE messages to accept from a peer within a heartbeat. MaxIHaveMessages int + // MaxIDontWantLength is the maximum number of messages to include in an IDONTWANT message. Also controls + // the maximum number of IDONTWANT ids we will accept to protect against IDONTWANT floods. This value + // should be adjusted if your system anticipates a larger amount than specified per heartbeat. + MaxIDontWantLength int // MaxIDontWantMessages is the maximum number of IDONTWANT messages to accept from a peer within a heartbeat. MaxIDontWantMessages int @@ -303,6 +308,7 @@ func DefaultGossipSubParams() GossipSubParams { GraftFloodThreshold: GossipSubGraftFloodThreshold, MaxIHaveLength: GossipSubMaxIHaveLength, MaxIHaveMessages: GossipSubMaxIHaveMessages, + MaxIDontWantLength: GossipSubMaxIDontWantLength, MaxIDontWantMessages: GossipSubMaxIDontWantMessages, IWantFollowupTime: GossipSubIWantFollowupTime, IDontWantMessageThreshold: GossipSubIDontWantMessageThreshold, @@ -1009,9 +1015,18 @@ func (gs *GossipSubRouter) handleIDontWant(p peer.ID, ctl *pb.ControlMessage) { } gs.peerdontwant[p]++ + totalUnwantedIds := 0 // Remember all the unwanted message ids +mainIDWLoop: for _, idontwant := range ctl.GetIdontwant() { for _, mid := range idontwant.GetMessageIDs() { + // IDONTWANT flood protection + if totalUnwantedIds >= gs.params.MaxIDontWantLength { + log.Debugf("IDONWANT: peer %s has advertised too many ids (%d) within this message; ignoring", p, totalUnwantedIds) + break mainIDWLoop + } + + totalUnwantedIds++ gs.unwanted[p][computeChecksum(mid)] = gs.params.IDontWantMessageTTL } } diff --git a/gossipsub_spam_test.go b/gossipsub_spam_test.go index df2fffff..9f6f0f94 100644 --- a/gossipsub_spam_test.go +++ b/gossipsub_spam_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/base64" + "fmt" "strconv" "sync" "testing" @@ -891,6 +892,53 @@ func TestGossipsubAttackSpamIDONTWANT(t *testing.T) { <-ctx.Done() } +func TestGossipsubHandleIDontwantSpam(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hosts := getDefaultHosts(t, 2) + + msgID := func(pmsg *pb.Message) string { + // silly content-based test message-ID: just use the data as whole + return base64.URLEncoding.EncodeToString(pmsg.Data) + } + + psubs := make([]*PubSub, 2) + psubs[0] = getGossipsub(ctx, hosts[0], WithMessageIdFn(msgID)) + psubs[1] = getGossipsub(ctx, hosts[1], WithMessageIdFn(msgID)) + + connect(t, hosts[0], hosts[1]) + + topic := "foobar" + for _, ps := range psubs { + _, err := ps.Subscribe(topic) + if err != nil { + t.Fatal(err) + } + } + exceededIDWLength := GossipSubMaxIDontWantLength + 1 + var idwIds []string + for i := 0; i < exceededIDWLength; i++ { + idwIds = append(idwIds, fmt.Sprintf("idontwant-%d", i)) + } + rPid := hosts[1].ID() + ctrlMessage := &pb.ControlMessage{Idontwant: []*pb.ControlIDontWant{{MessageIDs: idwIds}}} + grt := psubs[0].rt.(*GossipSubRouter) + grt.handleIDontWant(rPid, ctrlMessage) + + if grt.peerdontwant[rPid] != 1 { + t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid]) + } + mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1) + if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok { + t.Errorf("Desired message id was not stored in the unwanted map: %s", mid) + } + + mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength) + if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok { + t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid) + } +} + type mockGSOnRead func(writeMsg func(*pb.RPC), irpc *pb.RPC) func newMockGS(ctx context.Context, t *testing.T, attacker host.Host, onReadMsg mockGSOnRead) {