diff --git a/pubsub.go b/pubsub.go index dd058645..c9ddc3be 100644 --- a/pubsub.go +++ b/pubsub.go @@ -65,9 +65,6 @@ type PubSub struct { // incoming messages from other peers incoming chan *RPC - // messages we are publishing out to our peers - publish chan *Message - // addSub is a control channel for us to add and remove subscriptions addSub chan *addSubReq @@ -234,7 +231,6 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option signKey: nil, signPolicy: StrictSign, incoming: make(chan *RPC, 32), - publish: make(chan *Message), newPeers: make(chan peer.ID), newPeerStream: make(chan network.Stream), newPeerError: make(chan peer.ID), @@ -589,10 +585,6 @@ func (p *PubSub) processLoop(ctx context.Context) { case rpc := <-p.incoming: p.handleIncomingRPC(rpc) - case msg := <-p.publish: - p.tracer.PublishMessage(msg) - p.pushMsg(msg) - case msg := <-p.sendMsg: p.publishMessage(msg) @@ -999,22 +991,51 @@ func (p *PubSub) pushMsg(msg *Message) { return } + err := p.checkSignature(msg) + if err != nil { + log.Debugf("dropping message from %s: %s", src, err) + return + } + + // reject messages claiming to be from ourselves but not locally published + self := p.host.ID() + if peer.ID(msg.GetFrom()) == self && src != self { + log.Debugf("dropping message claiming to be from self but forwarded from %s", src) + p.tracer.RejectMessage(msg, RejectSelfOrigin) + return + } + + // have we already seen and validated this message? + id := p.msgID(msg.Message) + if p.seenMessage(id) { + p.tracer.DuplicateMessage(msg) + return + } + + if !p.val.Push(src, msg) { + return + } + + if p.markSeen(id) { + p.publishMessage(msg) + } +} + +func (p *PubSub) checkSignature(msg *Message) error { // reject unsigned messages when strict before we even process the id if p.signPolicy.mustVerify() { if p.signPolicy.mustSign() { if msg.Signature == nil { - log.Debugf("dropping unsigned message from %s", src) p.tracer.RejectMessage(msg, RejectMissingSignature) - return + return ValidationError{Reason: RejectMissingSignature} } // Actual signature verification happens in the validation pipeline, // after checking if the message was already seen or not, // to avoid unnecessary signature verification processing-cost. } else { if msg.Signature != nil { - log.Debugf("dropping message with unexpected signature from %s", src) p.tracer.RejectMessage(msg, RejectUnexpectedSignature) - return + return ValidationError{Reason: RejectUnexpectedSignature} } // If we are expecting signed messages, and not authoring messages, // then do no accept seq numbers, from data, or key data. @@ -1022,36 +1043,14 @@ func (p *PubSub) pushMsg(msg *Message) { // but is not used if we are not authoring messages ourselves. if p.signID == "" { if msg.Seqno != nil || msg.From != nil || msg.Key != nil { - log.Debugf("dropping message with unexpected auth info from %s", src) p.tracer.RejectMessage(msg, RejectUnexpectedAuthInfo) - return + return ValidationError{Reason: RejectUnexpectedAuthInfo} } } } } - // reject messages claiming to be from ourselves but not locally published - self := p.host.ID() - if peer.ID(msg.GetFrom()) == self && src != self { - log.Debugf("dropping message claiming to be from self but forwarded from %s", src) - p.tracer.RejectMessage(msg, RejectSelfOrigin) - return - } - - // have we already seen and validated this message? - id := p.msgID(msg.Message) - if p.seenMessage(id) { - p.tracer.DuplicateMessage(msg) - return - } - - if !p.val.Push(src, msg) { - return - } - - if p.markSeen(id) { - p.publishMessage(msg) - } + return nil } func (p *PubSub) publishMessage(msg *Message) { diff --git a/topic.go b/topic.go index 91a1ff15..e2e7b8ac 100644 --- a/topic.go +++ b/topic.go @@ -241,13 +241,7 @@ func (t *Topic) Publish(ctx context.Context, data []byte, opts ...PubOpt) error t.p.disc.Bootstrap(ctx, t.topic, pub.ready) } - select { - case t.p.publish <- &Message{m, t.p.host.ID(), nil}: - case <-t.p.ctx.Done(): - return t.p.ctx.Err() - } - - return nil + return t.p.val.Publish(&Message{m, t.p.host.ID(), nil}) } // WithReadiness returns a publishing option for only publishing when the router is ready. diff --git a/validation.go b/validation.go index 328b53a0..0a152625 100644 --- a/validation.go +++ b/validation.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "runtime" + "sync" "time" "github.com/libp2p/go-libp2p-core/peer" @@ -15,6 +16,16 @@ const ( defaultValidateThrottle = 8192 ) +// ValidationError is an error that may be signalled from message publication when the message +// fails validation +type ValidationError struct { + Reason string +} + +func (e ValidationError) Error() string { + return e.Reason +} + // Validator is a function that validates a message with a binary decision: accept or reject. type Validator func(context.Context, peer.ID, *Message) bool @@ -56,6 +67,8 @@ type validation struct { tracer *pubsubTracer + // mx protects the validator map + mx sync.Mutex // topicVals tracks per topic validators topicVals map[string]*topicVal @@ -123,6 +136,9 @@ func (v *validation) Start(p *PubSub) { // AddValidator adds a new validator func (v *validation) AddValidator(req *addValReq) { + v.mx.Lock() + defer v.mx.Unlock() + topic := req.topic _, ok := v.topicVals[topic] @@ -180,6 +196,9 @@ func (v *validation) AddValidator(req *addValReq) { // RemoveValidator removes an existing validator func (v *validation) RemoveValidator(req *rmValReq) { + v.mx.Lock() + defer v.mx.Unlock() + topic := req.topic _, ok := v.topicVals[topic] @@ -191,6 +210,20 @@ func (v *validation) RemoveValidator(req *rmValReq) { } } +// Publish synchronously accepts a locally published message, performs applicable +// validations and pushes the message for propagate by the pubsub system +func (v *validation) Publish(msg *Message) error { + v.p.tracer.PublishMessage(msg) + + err := v.p.checkSignature(msg) + if err != nil { + return err + } + + vals := v.getValidators(msg) + return v.validate(vals, msg.ReceivedFrom, msg, true) +} + // Push pushes a message into the validation pipeline. // It returns true if the message can be forwarded immediately without validation. func (v *validation) Push(src peer.ID, msg *Message) bool { @@ -211,6 +244,9 @@ func (v *validation) Push(src peer.ID, msg *Message) bool { // getValidators returns all validators that apply to a given message func (v *validation) getValidators(msg *Message) []*topicVal { + v.mx.Lock() + defer v.mx.Unlock() + topic := msg.GetTopic() val, ok := v.topicVals[topic] @@ -226,7 +262,7 @@ func (v *validation) validateWorker() { for { select { case req := <-v.validateQ: - v.validate(req.vals, req.src, req.msg) + v.validate(req.vals, req.src, req.msg, false) case <-v.p.ctx.Done(): return } @@ -234,16 +270,14 @@ func (v *validation) validateWorker() { } // validate performs validation and only sends the message if all validators succeed -// signature validation is performed synchronously, while user validators are invoked -// asynchronously, throttled by the global validation throttle. -func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) { +func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synchronous bool) error { // If signature verification is enabled, but signing is disabled, // the Signature is required to be nil upon receiving the message in PubSub.pushMsg. if msg.Signature != nil { if !v.validateSignature(msg) { log.Debugf("message signature validation failed; dropping message from %s", src) v.tracer.RejectMessage(msg, RejectInvalidSignature) - return + return ValidationError{Reason: RejectInvalidSignature} } } @@ -252,14 +286,14 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) { id := v.p.msgID(msg.Message) if !v.p.markSeen(id) { v.tracer.DuplicateMessage(msg) - return + return nil } else { v.tracer.ValidateMessage(msg) } var inline, async []*topicVal for _, val := range vals { - if val.validateInline { + if val.validateInline || synchronous { inline = append(inline, val) } else { async = append(async, val) @@ -283,7 +317,7 @@ loop: if result == ValidationReject { log.Debugf("message validation failed; dropping message from %s", src) v.tracer.RejectMessage(msg, RejectValidationFailed) - return + return ValidationError{Reason: RejectValidationFailed} } // apply async validators @@ -298,16 +332,21 @@ loop: log.Debugf("message validation throttled; dropping message from %s", src) v.tracer.RejectMessage(msg, RejectValidationThrottled) } - return + return nil } if result == ValidationIgnore { v.tracer.RejectMessage(msg, RejectValidationIgnored) - return + return ValidationError{Reason: RejectValidationIgnored} } // no async validators, accepted message, send it! - v.p.sendMsg <- msg + select { + case v.p.sendMsg <- msg: + return nil + case <-v.p.ctx.Done(): + return v.p.ctx.Err() + } } func (v *validation) validateSignature(msg *Message) bool {