diff --git a/pubsub.go b/pubsub.go index 3b74fa60..ae6aa96e 100644 --- a/pubsub.go +++ b/pubsub.go @@ -36,6 +36,8 @@ var ( var log = logging.Logger("pubsub") +type MatchingFunction func(string) func(string) bool + // PubSub is the implementation of the pubsub system. type PubSub struct { // atomic counter for seqnos @@ -157,6 +159,9 @@ type PubSub struct { // filter for tracking subscriptions in topics of interest; if nil, then we track all subscriptions subFilter SubscriptionFilter + // protoMatchFunc is a matching function for protocol selection. + protoMatchFunc *MatchingFunction + ctx context.Context } @@ -235,6 +240,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option peerOutboundQueueSize: 32, signID: h.ID(), signKey: nil, + protoMatchFunc: nil, signPolicy: StrictSign, incoming: make(chan *RPC, 32), newPeers: make(chan struct{}, 1), @@ -292,7 +298,11 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option rt.Attach(ps) for _, id := range rt.Protocols() { - h.SetStreamHandler(id, ps.handleNewStream) + if ps.protoMatchFunc != nil { + h.SetStreamHandlerMatch(id, (*ps.protoMatchFunc)(string(id)), ps.handleNewStream) + } else { + h.SetStreamHandler(id, ps.handleNewStream) + } } h.Network().Notify((*PubSubNotif)(ps)) @@ -475,6 +485,15 @@ func WithMaxMessageSize(maxMessageSize int) Option { } } +// WithProtocolMatchFunction sets a custom matching function for protocol +// selection to be used by the protocol handler on the Host's Mux +func WithProtocolMatchFunction(m MatchingFunction) Option { + return func(ps *PubSub) error { + ps.protoMatchFunc = &m + return nil + } +} + // processLoop handles all inputs arriving on the channels func (p *PubSub) processLoop(ctx context.Context) { defer func() {