diff --git a/node/pkg/common/guardianset.go b/node/pkg/common/guardianset.go index ea153a90d4..324c26c9db 100644 --- a/node/pkg/common/guardianset.go +++ b/node/pkg/common/guardianset.go @@ -214,3 +214,8 @@ func (st *GuardianSetState) Cleanup() { } } } + +// IsSubscribedToHeartbeats returns true if the heartbeat update channel is set. +func (st *GuardianSetState) IsSubscribedToHeartbeats() bool { + return st.updateC != nil +} diff --git a/node/pkg/common/guardianset_test.go b/node/pkg/common/guardianset_test.go index 0112c38eee..7b6e4c2630 100644 --- a/node/pkg/common/guardianset_test.go +++ b/node/pkg/common/guardianset_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + gossipv1 "github.com/certusone/wormhole/node/pkg/proto/gossip/v1" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" "github.com/wormhole-foundation/wormhole/sdk/vaa" @@ -122,3 +123,12 @@ func TestGet(t *testing.T) { gss.Set(&gs) assert.Equal(t, gss.Get(), &gs) } + +func TestIsSubscribedToHeartbeats(t *testing.T) { + heartbeatC := make(chan *gossipv1.Heartbeat, 20000) + gst1 := NewGuardianSetState(heartbeatC) + assert.True(t, gst1.IsSubscribedToHeartbeats()) + + gst2 := NewGuardianSetState(nil) + assert.False(t, gst2.IsSubscribedToHeartbeats()) +} diff --git a/node/pkg/p2p/p2p.go b/node/pkg/p2p/p2p.go index 6ff1e65703..b5ff281b22 100644 --- a/node/pkg/p2p/p2p.go +++ b/node/pkg/p2p/p2p.go @@ -362,7 +362,7 @@ func Run(params *RunParams) func(ctx context.Context) error { var controlSubscription, attestationSubscription, vaaSubscription *pubsub.Subscription // Set up the control channel. //////////////////////////////////////////////////////////////////// - if params.nodeName != "" || params.gossipControlSendC != nil || params.obsvReqSendC != nil || params.obsvReqRecvC != nil || params.signedGovCfgRecvC != nil || params.signedGovStatusRecvC != nil { + if params.nodeName != "" || params.gossipControlSendC != nil || params.obsvReqSendC != nil || params.obsvReqRecvC != nil || params.signedGovCfgRecvC != nil || params.signedGovStatusRecvC != nil || params.gst.IsSubscribedToHeartbeats() { controlTopic := fmt.Sprintf("%s/%s", params.networkID, "control") logger.Info("joining the control topic", zap.String("topic", controlTopic)) controlPubsubTopic, err = ps.Join(controlTopic) @@ -376,7 +376,7 @@ func Run(params *RunParams) func(ctx context.Context) error { } }() - if params.obsvReqRecvC != nil || params.signedGovCfgRecvC != nil || params.signedGovStatusRecvC != nil { + if params.obsvReqRecvC != nil || params.signedGovCfgRecvC != nil || params.signedGovStatusRecvC != nil || params.gst.IsSubscribedToHeartbeats() { logger.Info("subscribing to the control topic", zap.String("topic", controlTopic)) controlSubscription, err = controlPubsubTopic.Subscribe(pubsub.WithBufferSize(P2P_SUBSCRIPTION_BUFFER_SIZE)) if err != nil {