Skip to content

Commit

Permalink
Support multiple signers in weakcoin
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Oct 17, 2023
1 parent 26471f0 commit 7821cd7
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 155 deletions.
133 changes: 56 additions & 77 deletions beacon/beacon.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,9 @@ func WithConfig(cfg Config) Opt {
}
}

type weakCoinFactory func(signer *signing.EdSigner) coin

func withWeakCoinFactory(f weakCoinFactory) Opt {
func withWeakCoin(wc coin) Opt {
return func(pd *ProtocolDriver) {
pd.createWeakCoin = f
pd.weakCoin = wc
}
}

Expand All @@ -103,8 +101,7 @@ func New(
nonceFetcher: cdb,
cdb: cdb,
clock: clock,
signers: make(map[types.NodeID]participant),
createWeakCoin: nil,
signers: make(map[types.NodeID]*signing.EdSigner),
beacons: make(map[types.EpochID]types.Beacon),
ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight),
states: make(map[types.EpochID]*state),
Expand All @@ -121,20 +118,18 @@ func New(
pd.ctx, pd.cancel = context.WithCancel(pd.ctx)
pd.theta = new(big.Float).SetRat(pd.config.Theta)

if pd.createWeakCoin == nil {
pd.createWeakCoin = func(signer *signing.EdSigner) coin {
return weakcoin.New(
pd.publisher,
signer.VRFSigner(),
pd.vrfVerifier,
pd.nonceFetcher,
pd,
pd.msgTimes,
weakcoin.WithLog(pd.logger.WithName("weakCoin").WithName(signer.NodeID().ShortString())),
weakcoin.WithMaxRound(pd.config.RoundsNumber),
)
}
if pd.weakCoin == nil {
pd.weakCoin = weakcoin.New(
pd.publisher,
pd.vrfVerifier,
pd.nonceFetcher,
pd,
pd.msgTimes,
weakcoin.WithLog(pd.logger.WithName("weakCoin")),
weakcoin.WithMaxRound(pd.config.RoundsNumber),
)
}

pd.metricsCollector = metrics.NewBeaconMetricsCollector(pd.gatherMetricsData, pd.logger.WithName("metrics"))
return pd
}
Expand All @@ -143,35 +138,21 @@ func (pd *ProtocolDriver) Register(s *signing.EdSigner) {
pd.mu.Lock()
defer pd.mu.Unlock()
if _, exists := pd.signers[s.NodeID()]; exists {
pd.logger.With().Error("signing key already registered", log.ShortStringer("key", s.NodeID()))
pd.logger.With().Error("signing key already registered", log.ShortStringer("id", s.NodeID()))
return
}
p := participant{
signer: s,
coin: pd.createWeakCoin(s),
}
pd.logger.With().Info("registered signing key", p.Id())
pd.signers[s.NodeID()] = p

pd.logger.With().Info("registered signing key", log.ShortStringer("id", s.NodeID()))
pd.signers[s.NodeID()] = s
}

type participant struct {
signer *signing.EdSigner
coin coin
}

type signerSession struct {
participant
nonce types.VRFPostIndex
nonce types.VRFPostIndex
}

func (s *participant) Id() signerID {
return signerID(s.signer.NodeID())
}

type signerID types.NodeID

func (id signerID) Field() log.Field {
return log.ShortStringer("id", types.NodeID(id))
func (s *participant) Id() log.Field {
return log.ShortStringer("id", s.signer.NodeID())
}

// ProtocolDriver is the driver for the beacon protocol.
Expand All @@ -187,8 +168,8 @@ type ProtocolDriver struct {
sync system.SyncStateProvider
publisher pubsub.Publisher

signers map[types.NodeID]participant
createWeakCoin weakCoinFactory
signers map[types.NodeID]*signing.EdSigner
weakCoin coin

edVerifier *signing.EdVerifier
vrfVerifier vrfVerifier
Expand Down Expand Up @@ -563,7 +544,7 @@ func (pd *ProtocolDriver) initEpochStateIfNotPresent(logger log.Log, epoch types
var (
epochWeight uint64
miners = make(map[types.NodeID]*minerInfo)
potentiallyActive = make(map[types.NodeID]participant)
potentiallyActive = make(map[types.NodeID]*signing.EdSigner)
// w1 is the weight units at δ before the end of the previous epoch, used to calculate `thresholdStrict`
// w2 is the weight units at the end of the previous epoch, used to calculate `threshold`
w1, w2 int
Expand Down Expand Up @@ -609,14 +590,14 @@ func (pd *ProtocolDriver) initEpochStateIfNotPresent(logger log.Log, epoch types
return nil, errZeroEpochWeight
}

active := map[types.NodeID]signerSession{}
active := map[types.NodeID]participant{}
for id, signer := range potentiallyActive {
if nnc, err := pd.nonceFetcher.VRFNonce(id, epoch); err != nil {
if nonce, err := pd.nonceFetcher.VRFNonce(id, epoch); err != nil {
logger.With().Error("getting own VRF nonce", id, log.Err(err))
} else {
active[id] = signerSession{
participant: signer,
nonce: nnc,
active[id] = participant{
signer: signer,
nonce: nonce,
}
}
}
Expand Down Expand Up @@ -759,10 +740,8 @@ func (pd *ProtocolDriver) runProtocol(ctx context.Context, epoch types.EpochID,
pd.setBeginProtocol(ctx)
defer pd.setEndProtocol(ctx)

for _, participant := range st.active {
participant.coin.StartEpoch(ctx, epoch)
defer participant.coin.FinishEpoch(ctx, epoch)
}
pd.weakCoin.StartEpoch(ctx, epoch)
defer pd.weakCoin.FinishEpoch(ctx, epoch)

if err := pd.runProposalPhase(ctx, epoch, st); err != nil {
logger.With().Warning("proposal phase failed", log.Err(err))
Expand Down Expand Up @@ -831,7 +810,7 @@ func (pd *ProtocolDriver) runProposalPhase(ctx context.Context, epoch types.Epoc
return nil
}

func (pd *ProtocolDriver) sendProposal(ctx context.Context, epoch types.EpochID, s signerSession, checker eligibilityChecker) {
func (pd *ProtocolDriver) sendProposal(ctx context.Context, epoch types.EpochID, s participant, checker eligibilityChecker) {
if pd.isClosed() {
return
}
Expand Down Expand Up @@ -878,7 +857,6 @@ func (pd *ProtocolDriver) runConsensusPhase(ctx context.Context, epoch types.Epo
var (
ownVotes allVotes
undecided proposalList
err error
)

// First round
Expand Down Expand Up @@ -937,31 +915,33 @@ func (pd *ProtocolDriver) runConsensusPhase(ctx context.Context, epoch types.Epo
ownVotes, undecided = pd.calcVotesBeforeWeakCoin(rLogger, st)

timer.Reset(pd.config.WeakCoinRoundDuration)
for _, session := range st.active {
session := session
pd.eg.Go(func() error {
session.coin.StartRound(ctx, round, &session.nonce)
return nil
})
}

pd.eg.Go(func() error {
participants := make([]weakcoin.Participant, 0, len(st.active))
for _, session := range st.active {
participants = append(participants, weakcoin.Participant{
Signer: session.signer.VRFSigner(),
Nonce: session.nonce,
})
}
pd.weakCoin.StartRound(ctx, round, participants)
return nil
})

select {
case <-timer.C:
case <-ctx.Done():
return allVotes{}, fmt.Errorf("context done: %w", ctx.Err())
}
for _, session := range st.active {
session.coin.FinishRound(ctx)
}
// All weak coin should have the same result, so we can just take the first one
var flip bool
for _, session := range st.active {
flip, err = session.coin.Get(ctx, epoch, round)
if err != nil {
rLogger.With().Error("failed to generate weak coin", log.Err(err))
return allVotes{}, err
}
break

pd.weakCoin.FinishRound(ctx)

flip, err := pd.weakCoin.Get(ctx, epoch, round)
if err != nil {
rLogger.With().Error("failed to generate weak coin", log.Err(err))
return allVotes{}, err
}

tallyUndecided(&ownVotes, undecided, flip)
}

Expand All @@ -978,7 +958,6 @@ func (pd *ProtocolDriver) markProposalPhaseFinished(st *state, finishedAt time.T
func (pd *ProtocolDriver) calcVotesBeforeWeakCoin(logger log.Log, st *state) (allVotes, proposalList) {
pd.mu.RLock()
defer pd.mu.RUnlock()

return calcVotes(logger, pd.theta, st)
}

Expand All @@ -989,7 +968,7 @@ func (pd *ProtocolDriver) sendFirstRoundVote(ctx context.Context, msg FirstVotin
Signature: signer.Sign(signing.BEACON_FIRST_MSG, codec.MustEncode(&msg)),
}

pd.logger.WithContext(ctx).With().Debug("sending first round vote", msg.EpochID, types.FirstRound, signerID(signer.NodeID()))
pd.logger.WithContext(ctx).With().Debug("sending first round vote", msg.EpochID, types.FirstRound, log.ShortStringer("id", signer.NodeID()))
return pd.sendToGossip(ctx, pubsub.BeaconFirstVotesProtocol, codec.MustEncode(&m))
}

Expand Down Expand Up @@ -1024,7 +1003,7 @@ func (pd *ProtocolDriver) sendFollowingVote(ctx context.Context, epoch types.Epo
Signature: signer.Sign(signing.BEACON_FOLLOWUP_MSG, codec.MustEncode(&mb)),
}

pd.logger.WithContext(ctx).With().Debug("sending following round vote", epoch, round, signerID(signer.NodeID()))
pd.logger.WithContext(ctx).With().Debug("sending following round vote", epoch, round, log.ShortStringer("id", signer.NodeID()))
return pd.sendToGossip(ctx, pubsub.BeaconFollowingVotesProtocol, codec.MustEncode(&m))
}

Expand Down Expand Up @@ -1117,7 +1096,7 @@ func buildSignedProposal(ctx context.Context, logger log.Log, signer vrfSigner,
p := buildProposal(logger, epoch, nonce)
vrfSig := signer.Sign(p)
proposal := ProposalFromVrf(vrfSig)
logger.WithContext(ctx).With().Debug("calculated beacon proposal", epoch, nonce, log.Inline(proposal), signerID(signer.NodeID()))
logger.WithContext(ctx).With().Debug("calculated beacon proposal", epoch, nonce, log.Inline(proposal), log.ShortStringer("id", signer.NodeID()))
return vrfSig
}

Expand Down
16 changes: 8 additions & 8 deletions beacon/beacon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"go.uber.org/mock/gomock"

"github.com/spacemeshos/go-spacemesh/activation"
"github.com/spacemeshos/go-spacemesh/beacon/weakcoin"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/common/types/result"
"github.com/spacemeshos/go-spacemesh/datastore"
Expand All @@ -39,10 +40,9 @@ func coinValueMock(tb testing.TB, value bool) coin {
gomock.AssignableToTypeOf(types.EpochID(0)),
).AnyTimes()
coinMock.EXPECT().FinishEpoch(gomock.Any(), gomock.AssignableToTypeOf(types.EpochID(0))).AnyTimes()
nonce := types.VRFPostIndex(0)
coinMock.EXPECT().StartRound(gomock.Any(),
gomock.AssignableToTypeOf(types.RoundID(0)),
gomock.AssignableToTypeOf(&nonce),
gomock.AssignableToTypeOf([]weakcoin.Participant{}),
).AnyTimes()
coinMock.EXPECT().FinishRound(gomock.Any()).AnyTimes()
coinMock.EXPECT().Get(
Expand Down Expand Up @@ -94,7 +94,7 @@ func newTestDriver(tb testing.TB, cfg Config, p pubsub.Publisher, miners int, id
tpd.ProtocolDriver = New(p, signing.NewEdVerifier(), tpd.mVerifier, tpd.cdb, tpd.mClock,
WithConfig(cfg),
WithLogger(lg),
withWeakCoinFactory(func(signer *signing.EdSigner) coin { return coinValueMock(tb, true) }),
withWeakCoin(coinValueMock(tb, true)),
)
tpd.ProtocolDriver.SetSyncState(tpd.mSync)
for i := 0; i < miners; i++ {
Expand Down Expand Up @@ -192,7 +192,7 @@ func TestBeacon_MultipleNodes(t *testing.T) {

for _, db := range dbs {
for _, s := range node.signers {
createATX(t, db, atxPublishLid, s.signer, 1, time.Now().Add(-1*time.Second))
createATX(t, db, atxPublishLid, s, 1, time.Now().Add(-1*time.Second))
}
}
}
Expand Down Expand Up @@ -259,9 +259,9 @@ func TestBeacon_MultipleNodes_OnlyOneHonest(t *testing.T) {
for i, node := range testNodes {
for _, db := range dbs {
for _, s := range node.signers {
createATX(t, db, atxPublishLid, s.signer, 1, time.Now().Add(-1*time.Second))
createATX(t, db, atxPublishLid, s, 1, time.Now().Add(-1*time.Second))
if i != 0 {
require.NoError(t, identities.SetMalicious(db, s.signer.NodeID(), []byte("bad"), time.Now()))
require.NoError(t, identities.SetMalicious(db, s.NodeID(), []byte("bad"), time.Now()))
}
}
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func TestBeacon_NoProposals(t *testing.T) {
for _, node := range testNodes {
for _, db := range dbs {
for _, s := range node.signers {
createATX(t, db, atxPublishLid, s.signer, 1, time.Now().Add(-1*time.Second))
createATX(t, db, atxPublishLid, s, 1, time.Now().Add(-1*time.Second))
}
}
}
Expand Down Expand Up @@ -409,7 +409,7 @@ func TestBeaconWithMetrics(t *testing.T) {
for i := types.EpochID(2); i < epoch; i++ {
lid := i.FirstLayer().Sub(1)
for _, s := range tpd.signers {
createATX(t, tpd.cdb, lid, s.signer, 199, time.Now())
createATX(t, tpd.cdb, lid, s, 199, time.Now())
}
createRandomATXs(t, tpd.cdb, lid, 9)
}
Expand Down
14 changes: 1 addition & 13 deletions beacon/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"math/big"
"time"

"golang.org/x/sync/errgroup"

bcnmetrics "github.com/spacemeshos/go-spacemesh/beacon/metrics"
"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/common/types"
Expand Down Expand Up @@ -44,17 +42,7 @@ func (pd *ProtocolDriver) HandleWeakCoinProposal(ctx context.Context, peer p2p.P
if !pd.isInProtocol() {
return errBeaconProtocolInactive
}
var eg errgroup.Group
pd.mu.RLock()
for _, signer := range pd.signers {
coin := signer.coin
eg.Go(func() error {
return coin.HandleProposal(ctx, peer, msg)
})
}
pd.mu.Unlock()

return eg.Wait()
return pd.weakCoin.HandleProposal(ctx, peer, msg)
}

// HandleProposal handles beacon proposal from gossip.
Expand Down
3 changes: 2 additions & 1 deletion beacon/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"time"

"github.com/spacemeshos/go-spacemesh/beacon/weakcoin"
"github.com/spacemeshos/go-spacemesh/common/types"
"github.com/spacemeshos/go-spacemesh/p2p"
)
Expand All @@ -12,7 +13,7 @@ import (

type coin interface {
StartEpoch(context.Context, types.EpochID)
StartRound(context.Context, types.RoundID, *types.VRFPostIndex)
StartRound(context.Context, types.RoundID, []weakcoin.Participant)
FinishRound(context.Context)
Get(context.Context, types.EpochID, types.RoundID) (bool, error)
FinishEpoch(context.Context, types.EpochID)
Expand Down
Loading

0 comments on commit 7821cd7

Please sign in to comment.