diff --git a/consensus/reactor.go b/consensus/reactor.go index 1f508319d2e..eb7e6957454 100644 --- a/consensus/reactor.go +++ b/consensus/reactor.go @@ -1444,11 +1444,17 @@ func (m *NewValidBlockMessage) ValidateBasic() error { if err := m.BlockPartsHeader.ValidateBasic(); err != nil { return fmt.Errorf("Wrong BlockPartsHeader: %v", err) } + if m.BlockParts.Size() == 0 { + return errors.New("Empty BlockParts") + } if m.BlockParts.Size() != m.BlockPartsHeader.Total { return fmt.Errorf("BlockParts bit array size %d not equal to BlockPartsHeader.Total %d", m.BlockParts.Size(), m.BlockPartsHeader.Total) } + if m.BlockParts.Size() > types.MaxBlockPartsCount { + return errors.Errorf("BlockParts bit array is too big: %d, max: %d", m.BlockParts.Size(), types.MaxBlockPartsCount) + } return nil } @@ -1495,6 +1501,9 @@ func (m *ProposalPOLMessage) ValidateBasic() error { if m.ProposalPOL.Size() == 0 { return errors.New("Empty ProposalPOL bit array") } + if m.ProposalPOL.Size() > types.MaxVotesCount { + return errors.Errorf("ProposalPOL bit array is too big: %d, max: %d", m.ProposalPOL.Size(), types.MaxVotesCount) + } return nil } @@ -1638,6 +1647,9 @@ func (m *VoteSetBitsMessage) ValidateBasic() error { return fmt.Errorf("Wrong BlockID: %v", err) } // NOTE: Votes.Size() can be zero if the node does not have any + if m.Votes.Size() > types.MaxVotesCount { + return fmt.Errorf("Votes bit array is too big: %d, max: %d", m.Votes.Size(), types.MaxVotesCount) + } return nil } diff --git a/consensus/reactor_test.go b/consensus/reactor_test.go index be3f607e4f7..17e9cf762c7 100644 --- a/consensus/reactor_test.go +++ b/consensus/reactor_test.go @@ -19,6 +19,9 @@ import ( abci "github.com/tendermint/tendermint/abci/types" bc "github.com/tendermint/tendermint/blockchain" cfg "github.com/tendermint/tendermint/config" + cstypes "github.com/tendermint/tendermint/consensus/types" + "github.com/tendermint/tendermint/crypto/tmhash" + cmn "github.com/tendermint/tendermint/libs/common" dbm "github.com/tendermint/tendermint/libs/db" "github.com/tendermint/tendermint/libs/log" mempl "github.com/tendermint/tendermint/mempool" @@ -566,3 +569,269 @@ func capture() { count := runtime.Stack(trace, true) fmt.Printf("Stack of %d bytes: %s\n", count, trace) } + +//------------------------------------------------------------- +// Ensure basic validation of structs is functioning + +func TestNewRoundStepMessageValidateBasic(t *testing.T) { + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageLastCommitRound int + messageHeight int64 + testName string + messageStep cstypes.RoundStepType + }{ + {false, 0, 0, 0, "Valid Message", 0x01}, + {true, -1, 0, 0, "Invalid Message", 0x01}, + {true, 0, 0, -1, "Invalid Message", 0x01}, + {true, 0, 0, 1, "Invalid Message", 0x00}, + {true, 0, 0, 1, "Invalid Message", 0x00}, + {true, 0, -2, 2, "Invalid Message", 0x01}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := NewRoundStepMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Step: tc.messageStep, + LastCommitRound: tc.messageLastCommitRound, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestNewValidBlockMessageValidateBasic(t *testing.T) { + testCases := []struct { + malleateFn func(*NewValidBlockMessage) + expErr string + }{ + {func(msg *NewValidBlockMessage) {}, ""}, + {func(msg *NewValidBlockMessage) { msg.Height = -1 }, "Negative Height"}, + {func(msg *NewValidBlockMessage) { msg.Round = -1 }, "Negative Round"}, + { + func(msg *NewValidBlockMessage) { msg.BlockPartsHeader.Total = 2 }, + "BlockParts bit array size 1 not equal to BlockPartsHeader.Total 2", + }, + { + func(msg *NewValidBlockMessage) { msg.BlockPartsHeader.Total = 0; msg.BlockParts = cmn.NewBitArray(0) }, + "Empty BlockParts", + }, + { + func(msg *NewValidBlockMessage) { msg.BlockParts = cmn.NewBitArray(types.MaxBlockPartsCount + 1) }, + "BlockParts bit array size 1602 not equal to BlockPartsHeader.Total 1", + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + msg := &NewValidBlockMessage{ + Height: 1, + Round: 0, + BlockPartsHeader: types.PartSetHeader{ + Total: 1, + }, + BlockParts: cmn.NewBitArray(1), + } + + tc.malleateFn(msg) + err := msg.ValidateBasic() + if tc.expErr != "" && assert.Error(t, err) { + assert.Contains(t, err.Error(), tc.expErr) + } + }) + } +} + +func TestProposalPOLMessageValidateBasic(t *testing.T) { + testCases := []struct { + malleateFn func(*ProposalPOLMessage) + expErr string + }{ + {func(msg *ProposalPOLMessage) {}, ""}, + {func(msg *ProposalPOLMessage) { msg.Height = -1 }, "Negative Height"}, + {func(msg *ProposalPOLMessage) { msg.ProposalPOLRound = -1 }, "Negative ProposalPOLRound"}, + {func(msg *ProposalPOLMessage) { msg.ProposalPOL = cmn.NewBitArray(0) }, "Empty ProposalPOL bit array"}, + {func(msg *ProposalPOLMessage) { msg.ProposalPOL = cmn.NewBitArray(types.MaxVotesCount + 1) }, + "ProposalPOL bit array is too big: 10001, max: 10000"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + msg := &ProposalPOLMessage{ + Height: 1, + ProposalPOLRound: 1, + ProposalPOL: cmn.NewBitArray(1), + } + + tc.malleateFn(msg) + err := msg.ValidateBasic() + if tc.expErr != "" && assert.Error(t, err) { + assert.Contains(t, err.Error(), tc.expErr) + } + }) + } +} + +func TestBlockPartMessageValidateBasic(t *testing.T) { + testPart := new(types.Part) + testPart.Proof.LeafHash = tmhash.Sum([]byte("leaf")) + testCases := []struct { + testName string + messageHeight int64 + messageRound int + messagePart *types.Part + expectErr bool + }{ + {"Valid Message", 0, 0, testPart, false}, + {"Invalid Message", -1, 0, testPart, true}, + {"Invalid Message", 0, -1, testPart, true}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := BlockPartMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Part: tc.messagePart, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } + + message := BlockPartMessage{Height: 0, Round: 0, Part: new(types.Part)} + message.Part.Index = -1 + + assert.Equal(t, true, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") +} + +func TestHasVoteMessageValidateBasic(t *testing.T) { + const ( + validSignedMsgType types.SignedMsgType = 0x01 + invalidSignedMsgType types.SignedMsgType = 0x03 + ) + + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageIndex int + messageHeight int64 + testName string + messageType types.SignedMsgType + }{ + {false, 0, 0, 0, "Valid Message", validSignedMsgType}, + {true, -1, 0, 0, "Invalid Message", validSignedMsgType}, + {true, 0, -1, 0, "Invalid Message", validSignedMsgType}, + {true, 0, 0, 0, "Invalid Message", invalidSignedMsgType}, + {true, 0, 0, -1, "Invalid Message", validSignedMsgType}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := HasVoteMessage{ + Height: tc.messageHeight, + Round: tc.messageRound, + Type: tc.messageType, + Index: tc.messageIndex, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestVoteSetMaj23MessageValidateBasic(t *testing.T) { + const ( + validSignedMsgType types.SignedMsgType = 0x01 + invalidSignedMsgType types.SignedMsgType = 0x03 + ) + + validBlockID := types.BlockID{} + invalidBlockID := types.BlockID{ + Hash: cmn.HexBytes{}, + PartsHeader: types.PartSetHeader{ + Total: -1, + Hash: cmn.HexBytes{}, + }, + } + + testCases := []struct { // nolint: maligned + expectErr bool + messageRound int + messageHeight int64 + testName string + messageType types.SignedMsgType + messageBlockID types.BlockID + }{ + {false, 0, 0, "Valid Message", validSignedMsgType, validBlockID}, + {true, -1, 0, "Invalid Message", validSignedMsgType, validBlockID}, + {true, 0, -1, "Invalid Message", validSignedMsgType, validBlockID}, + {true, 0, 0, "Invalid Message", invalidSignedMsgType, validBlockID}, + {true, 0, 0, "Invalid Message", validSignedMsgType, invalidBlockID}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + message := VoteSetMaj23Message{ + Height: tc.messageHeight, + Round: tc.messageRound, + Type: tc.messageType, + BlockID: tc.messageBlockID, + } + + assert.Equal(t, tc.expectErr, message.ValidateBasic() != nil, "Validate Basic had an unexpected result") + }) + } +} + +func TestVoteSetBitsMessageValidateBasic(t *testing.T) { + testCases := []struct { // nolint: maligned + malleateFn func(*VoteSetBitsMessage) + expErr string + }{ + {func(msg *VoteSetBitsMessage) {}, ""}, + {func(msg *VoteSetBitsMessage) { msg.Height = -1 }, "Negative Height"}, + {func(msg *VoteSetBitsMessage) { msg.Round = -1 }, "Negative Round"}, + {func(msg *VoteSetBitsMessage) { msg.Type = 0x03 }, "Invalid Type"}, + {func(msg *VoteSetBitsMessage) { + msg.BlockID = types.BlockID{ + Hash: cmn.HexBytes{}, + PartsHeader: types.PartSetHeader{ + Total: -1, + Hash: cmn.HexBytes{}, + }, + } + }, "Wrong BlockID: Wrong PartsHeader: Negative Total"}, + {func(msg *VoteSetBitsMessage) { msg.Votes = cmn.NewBitArray(types.MaxVotesCount + 1) }, + "Votes bit array is too big: 10001, max: 10000"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + msg := &VoteSetBitsMessage{ + Height: 1, + Round: 0, + Type: 0x01, + Votes: cmn.NewBitArray(1), + BlockID: types.BlockID{}, + } + + tc.malleateFn(msg) + err := msg.ValidateBasic() + if tc.expErr != "" && assert.Error(t, err) { + assert.Contains(t, err.Error(), tc.expErr) + } + }) + } +} diff --git a/consensus/replay_test.go b/consensus/replay_test.go index 4233ef44e61..4a10dda8763 100644 --- a/consensus/replay_test.go +++ b/consensus/replay_test.go @@ -212,15 +212,15 @@ func (e ReachedHeightToStopError) Error() string { // Write simulate WAL's crashing by sending an error to the panicCh and then // exiting the cs.receiveRoutine. -func (w *crashingWAL) Write(m WALMessage) { +func (w *crashingWAL) Write(m WALMessage) error { if endMsg, ok := m.(EndHeightMessage); ok { if endMsg.Height == w.heightToStop { w.panicCh <- ReachedHeightToStopError{endMsg.Height} runtime.Goexit() - } else { - w.next.Write(m) + return nil } - return + + return w.next.Write(m) } if w.msgIndex > w.lastPanicedForMsgIndex { @@ -228,14 +228,15 @@ func (w *crashingWAL) Write(m WALMessage) { _, file, line, _ := runtime.Caller(1) w.panicCh <- WALWriteError{fmt.Sprintf("failed to write %T to WAL (fileline: %s:%d)", m, file, line)} runtime.Goexit() - } else { - w.msgIndex++ - w.next.Write(m) + return nil } + + w.msgIndex++ + return w.next.Write(m) } -func (w *crashingWAL) WriteSync(m WALMessage) { - w.Write(m) +func (w *crashingWAL) WriteSync(m WALMessage) error { + return w.Write(m) } func (w *crashingWAL) Group() *auto.Group { return w.next.Group() } diff --git a/consensus/state.go b/consensus/state.go index 58bcb97622c..6e0f3b8fecd 100644 --- a/consensus/state.go +++ b/consensus/state.go @@ -636,7 +636,10 @@ func (cs *ConsensusState) receiveRoutine(maxSteps int) { // may generate internal events (votes, complete proposals, 2/3 majorities) cs.handleMsg(mi) case mi = <-cs.internalMsgQueue: - cs.wal.WriteSync(mi) // NOTE: fsync + err := cs.wal.WriteSync(mi) // NOTE: fsync + if err != nil { + panic(fmt.Sprintf("Failed to write %v msg to consensus wal due to %v. Check your FS and restart the node", mi, err)) + } if _, ok := mi.Msg.(*VoteMessage); ok { // we actually want to simulate failing during @@ -1318,7 +1321,10 @@ func (cs *ConsensusState) finalizeCommit(height int64) { // Either way, the ConsensusState should not be resumed until we // successfully call ApplyBlock (ie. later here, or in Handshake after // restart). - cs.wal.WriteSync(EndHeightMessage{height}) // NOTE: fsync + me := EndHeightMessage{height} + if err := cs.wal.WriteSync(me); err != nil { // NOTE: fsync + panic(fmt.Sprintf("Failed to write %v msg to consensus wal due to %v. Check your FS and restart the node", me, err)) + } fail.Fail() // XXX diff --git a/consensus/wal.go b/consensus/wal.go index 26428a4c626..b817f86d926 100644 --- a/consensus/wal.go +++ b/consensus/wal.go @@ -19,15 +19,20 @@ import ( ) const ( - // must be greater than types.BlockPartSizeBytes + a few bytes - maxMsgSizeBytes = 1024 * 1024 // 1MB + + // amino overhead + time.Time + max consensus msg size + maxMsgSizeBytes = maxMsgSize + 24 + + // how often the WAL should be sync'd during period sync'ing + walDefaultFlushInterval = 2 * time.Second ) //-------------------------------------------------------- // types and functions for savings consensus messages +// TimedWALMessage wraps WALMessage and adds Time for debugging purposes. type TimedWALMessage struct { - Time time.Time `json:"time"` // for debugging purposes + Time time.Time `json:"time"` Msg WALMessage `json:"msg"` } @@ -52,20 +57,22 @@ func RegisterWALMessages(cdc *amino.Codec) { // WAL is an interface for any write-ahead logger. type WAL interface { - Write(WALMessage) - WriteSync(WALMessage) + Write(WALMessage) error + WriteSync(WALMessage) error Group() *auto.Group SearchForEndHeight(height int64, options *WALSearchOptions) (gr *auto.GroupReader, found bool, err error) + // service methods Start() error Stop() error Wait() } // Write ahead logger writes msgs to disk before they are processed. -// Can be used for crash-recovery and deterministic replay -// TODO: currently the wal is overwritten during replay catchup -// give it a mode so it's either reading or appending - must read to end to start appending again +// Can be used for crash-recovery and deterministic replay. +// TODO: currently the wal is overwritten during replay catchup, give it a mode +// so it's either reading or appending - must read to end to start appending +// again. type baseWAL struct { cmn.BaseService @@ -130,29 +137,39 @@ func (wal *baseWAL) Wait() { // Write is called in newStep and for each receive on the // peerMsgQueue and the timeoutTicker. // NOTE: does not call fsync() -func (wal *baseWAL) Write(msg WALMessage) { +func (wal *baseWAL) Write(msg WALMessage) error { if wal == nil { - return + return nil } - // Write the wal message if err := wal.enc.Encode(&TimedWALMessage{tmtime.Now(), msg}); err != nil { - panic(fmt.Sprintf("Error writing msg to consensus wal: %v \n\nMessage: %v", err, msg)) + wal.Logger.Error("Error writing msg to consensus wal. WARNING: recover may not be possible for the current height", + "err", err, "msg", msg) + return err } + + return nil } // WriteSync is called when we receive a msg from ourselves // so that we write to disk before sending signed messages. // NOTE: calls fsync() -func (wal *baseWAL) WriteSync(msg WALMessage) { +func (wal *baseWAL) WriteSync(msg WALMessage) error { if wal == nil { - return + return nil + } + + if err := wal.Write(msg); err != nil { + return err } - wal.Write(msg) if err := wal.group.Flush(); err != nil { - panic(fmt.Sprintf("Error flushing consensus wal buf to file. Error: %v \n", err)) + wal.Logger.Error("WriteSync failed to flush consensus wal. WARNING: may result in creating alternative proposals / votes for the current height iff the node restarted", + "err", err) + return err } + + return nil } // WALSearchOptions are optional arguments to SearchForEndHeight. @@ -229,12 +246,17 @@ func NewWALEncoder(wr io.Writer) *WALEncoder { return &WALEncoder{wr} } -// Encode writes the custom encoding of v to the stream. +// Encode writes the custom encoding of v to the stream. It returns an error if +// the amino-encoded size of v is greater than 1MB. Any error encountered +// during the write is also returned. func (enc *WALEncoder) Encode(v *TimedWALMessage) error { data := cdc.MustMarshalBinaryBare(v) crc := crc32.Checksum(data, crc32c) length := uint32(len(data)) + if length > maxMsgSizeBytes { + return fmt.Errorf("msg is too big: %d bytes, max: %d bytes", length, maxMsgSizeBytes) + } totalLength := 8 + int(length) msg := make([]byte, totalLength) @@ -243,7 +265,6 @@ func (enc *WALEncoder) Encode(v *TimedWALMessage) error { copy(msg[8:], data) _, err := enc.wr.Write(msg) - return err } @@ -307,15 +328,15 @@ func (dec *WALDecoder) Decode() (*TimedWALMessage, error) { } data := make([]byte, length) - _, err = dec.rd.Read(data) + n, err := dec.rd.Read(data) if err != nil { - return nil, DataCorruptionError{fmt.Errorf("failed to read data: %v", err)} + return nil, DataCorruptionError{fmt.Errorf("failed to read data: %v (read: %d, wanted: %d)", err, n, length)} } // check checksum before decoding data actualCRC := crc32.Checksum(data, crc32c) if actualCRC != crc { - return nil, DataCorruptionError{fmt.Errorf("checksums do not match: (read: %v, actual: %v)", crc, actualCRC)} + return nil, DataCorruptionError{fmt.Errorf("checksums do not match: read: %v, actual: %v", crc, actualCRC)} } var res = new(TimedWALMessage) // nolint: gosimple @@ -329,9 +350,9 @@ func (dec *WALDecoder) Decode() (*TimedWALMessage, error) { type nilWAL struct{} -func (nilWAL) Write(m WALMessage) {} -func (nilWAL) WriteSync(m WALMessage) {} -func (nilWAL) Group() *auto.Group { return nil } +func (nilWAL) Write(m WALMessage) error { return nil } +func (nilWAL) WriteSync(m WALMessage) error { return nil } +func (nilWAL) Group() *auto.Group { return nil } func (nilWAL) SearchForEndHeight(height int64, options *WALSearchOptions) (gr *auto.GroupReader, found bool, err error) { return nil, false, nil } diff --git a/consensus/wal_generator.go b/consensus/wal_generator.go index 5ff597a52e2..4f546ebb85a 100644 --- a/consensus/wal_generator.go +++ b/consensus/wal_generator.go @@ -178,10 +178,10 @@ func newByteBufferWAL(logger log.Logger, enc *WALEncoder, nBlocks int64, signalS // Save writes message to the internal buffer except when heightToStop is // reached, in which case it will signal the caller via signalWhenStopsTo and // skip writing. -func (w *byteBufferWAL) Write(m WALMessage) { +func (w *byteBufferWAL) Write(m WALMessage) error { if w.stopped { w.logger.Debug("WAL already stopped. Not writing message", "msg", m) - return + return nil } if endMsg, ok := m.(EndHeightMessage); ok { @@ -190,7 +190,7 @@ func (w *byteBufferWAL) Write(m WALMessage) { w.logger.Debug("Stopping WAL at height", "height", endMsg.Height) w.signalWhenStopsTo <- struct{}{} w.stopped = true - return + return nil } } @@ -199,10 +199,12 @@ func (w *byteBufferWAL) Write(m WALMessage) { if err != nil { panic(fmt.Sprintf("failed to encode the msg %v", m)) } + + return nil } -func (w *byteBufferWAL) WriteSync(m WALMessage) { - w.Write(m) +func (w *byteBufferWAL) WriteSync(m WALMessage) error { + return w.Write(m) } func (w *byteBufferWAL) Group() *auto.Group { diff --git a/consensus/wal_test.go b/consensus/wal_test.go index 54219264f0f..daed449c21a 100644 --- a/consensus/wal_test.go +++ b/consensus/wal_test.go @@ -3,22 +3,27 @@ package consensus import ( "bytes" "crypto/rand" - "fmt" "io/ioutil" "os" "path/filepath" + // "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/consensus/types" + "github.com/tendermint/tendermint/crypto/merkle" "github.com/tendermint/tendermint/libs/autofile" "github.com/tendermint/tendermint/libs/log" tmtypes "github.com/tendermint/tendermint/types" tmtime "github.com/tendermint/tendermint/types/time" +) - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" +const ( + walTestFlushInterval = time.Duration(100) * time.Millisecond ) func TestWALTruncate(t *testing.T) { @@ -28,8 +33,10 @@ func TestWALTruncate(t *testing.T) { walFile := filepath.Join(walDir, "wal") - //this magic number 4K can truncate the content when RotateFile. defaultHeadSizeLimit(10M) is hard to simulate. - //this magic number 1 * time.Millisecond make RotateFile check frequently. defaultGroupCheckDuration(5s) is hard to simulate. + // this magic number 4K can truncate the content when RotateFile. + // defaultHeadSizeLimit(10M) is hard to simulate. + // this magic number 1 * time.Millisecond make RotateFile check frequently. + // defaultGroupCheckDuration(5s) is hard to simulate. wal, err := NewWAL(walFile, autofile.GroupHeadSizeLimit(4096), autofile.GroupCheckDuration(1*time.Millisecond), @@ -45,20 +52,21 @@ func TestWALTruncate(t *testing.T) { wal.Wait() }() - //60 block's size nearly 70K, greater than group's headBuf size(4096 * 10), when headBuf is full, truncate content will Flush to the file. - //at this time, RotateFile is called, truncate content exist in each file. + // 60 block's size nearly 70K, greater than group's headBuf size(4096 * 10), + // when headBuf is full, truncate content will Flush to the file. at this + // time, RotateFile is called, truncate content exist in each file. err = WALGenerateNBlocks(wal.Group(), 60) require.NoError(t, err) time.Sleep(1 * time.Millisecond) //wait groupCheckDuration, make sure RotateFile run - wal.Group().Flush() + wal.group.Flush() h := int64(50) gr, found, err := wal.SearchForEndHeight(h, &WALSearchOptions{}) - assert.NoError(t, err, fmt.Sprintf("expected not to err on height %d", h)) - assert.True(t, found, fmt.Sprintf("expected to find end height for %d", h)) - assert.NotNil(t, gr, "expected group not to be nil") + assert.NoError(t, err, "expected not to err on height %d", h) + assert.True(t, found, "expected to find end height for %d", h) + assert.NotNil(t, gr) defer gr.Close() dec := NewWALDecoder(gr) @@ -66,14 +74,14 @@ func TestWALTruncate(t *testing.T) { assert.NoError(t, err, "expected to decode a message") rs, ok := msg.Msg.(tmtypes.EventDataRoundState) assert.True(t, ok, "expected message of type EventDataRoundState") - assert.Equal(t, rs.Height, h+1, fmt.Sprintf("wrong height")) + assert.Equal(t, rs.Height, h+1, "wrong height") } func TestWALEncoderDecoder(t *testing.T) { now := tmtime.Now() msgs := []TimedWALMessage{ - TimedWALMessage{Time: now, Msg: EndHeightMessage{0}}, - TimedWALMessage{Time: now, Msg: timeoutInfo{Duration: time.Second, Height: 1, Round: 1, Step: types.RoundStepPropose}}, + {Time: now, Msg: EndHeightMessage{0}}, + {Time: now, Msg: timeoutInfo{Duration: time.Second, Height: 1, Round: 1, Step: types.RoundStepPropose}}, } b := new(bytes.Buffer) @@ -94,6 +102,43 @@ func TestWALEncoderDecoder(t *testing.T) { } } +func TestWALWrite(t *testing.T) { + walDir, err := ioutil.TempDir("", "wal") + require.NoError(t, err) + defer os.RemoveAll(walDir) + walFile := filepath.Join(walDir, "wal") + + wal, err := NewWAL(walFile) + require.NoError(t, err) + err = wal.Start() + require.NoError(t, err) + defer func() { + wal.Stop() + // wait for the wal to finish shutting down so we + // can safely remove the directory + wal.Wait() + }() + + // 1) Write returns an error if msg is too big + msg := &BlockPartMessage{ + Height: 1, + Round: 1, + Part: &tmtypes.Part{ + Index: 1, + Bytes: make([]byte, 1), + Proof: merkle.SimpleProof{ + Total: 1, + Index: 1, + LeafHash: make([]byte, maxMsgSizeBytes-30), + }, + }, + } + err = wal.Write(msg) + if assert.Error(t, err) { + assert.Equal(t, "msg is too big: 1048593 bytes, max: 1048576 bytes", err.Error()) + } +} + func TestWALSearchForEndHeight(t *testing.T) { walBody, err := WALWithNBlocks(6) if err != nil { @@ -107,9 +152,9 @@ func TestWALSearchForEndHeight(t *testing.T) { h := int64(3) gr, found, err := wal.SearchForEndHeight(h, &WALSearchOptions{}) - assert.NoError(t, err, fmt.Sprintf("expected not to err on height %d", h)) - assert.True(t, found, fmt.Sprintf("expected to find end height for %d", h)) - assert.NotNil(t, gr, "expected group not to be nil") + assert.NoError(t, err, "expected not to err on height %d", h) + assert.True(t, found, "expected to find end height for %d", h) + assert.NotNil(t, gr) defer gr.Close() dec := NewWALDecoder(gr) @@ -117,9 +162,51 @@ func TestWALSearchForEndHeight(t *testing.T) { assert.NoError(t, err, "expected to decode a message") rs, ok := msg.Msg.(tmtypes.EventDataRoundState) assert.True(t, ok, "expected message of type EventDataRoundState") - assert.Equal(t, rs.Height, h+1, fmt.Sprintf("wrong height")) + assert.Equal(t, rs.Height, h+1, "wrong height") } +/* +// TODO: cherry pick FlushInterval +func TestWALPeriodicSync(t *testing.T) { + walDir, err := ioutil.TempDir("", "wal") + require.NoError(t, err) + defer os.RemoveAll(walDir) + + walFile := filepath.Join(walDir, "wal") + wal, err := NewWAL(walFile, autofile.GroupCheckDuration(1*time.Millisecond)) + require.NoError(t, err) + + wal.SetFlushInterval(walTestFlushInterval) + wal.SetLogger(log.TestingLogger()) + + // Generate some data + err = WALGenerateNBlocks(t, wal.Group(), 5) + require.NoError(t, err) + + // We should have data in the buffer now + assert.NotZero(t, wal.Group().Buffered()) + + require.NoError(t, wal.Start()) + defer func() { + wal.Stop() + wal.Wait() + }() + + time.Sleep(walTestFlushInterval + (10 * time.Millisecond)) + + // The data should have been flushed by the periodic sync + assert.Zero(t, wal.Group().Buffered()) + + h := int64(4) + gr, found, err := wal.SearchForEndHeight(h, &WALSearchOptions{}) + assert.NoError(t, err, "expected not to err on height %d", h) + assert.True(t, found, "expected to find end height for %d", h) + assert.NotNil(t, gr) + if gr != nil { + gr.Close() + } +}*/ + /* var initOnce sync.Once diff --git a/crypto/merkle/simple_proof.go b/crypto/merkle/simple_proof.go index fd6d07b88c9..f0d7188276f 100644 --- a/crypto/merkle/simple_proof.go +++ b/crypto/merkle/simple_proof.go @@ -2,13 +2,19 @@ package merkle import ( "bytes" - "errors" "fmt" + "github.com/pkg/errors" + "github.com/tendermint/tendermint/crypto/tmhash" cmn "github.com/tendermint/tendermint/libs/common" ) +const ( + // given maxMsgSizeBytes in consensus wal is 1MB + maxAunts = 30000 +) + // SimpleProof represents a simple Merkle proof. // NOTE: The convention for proofs is to include leaf hashes but to // exclude the root hash. @@ -109,6 +115,29 @@ func (sp *SimpleProof) StringIndented(indent string) string { indent) } +// ValidateBasic performs basic validation. +// NOTE: it expects LeafHash and Aunts of tmhash.Size size. +func (sp *SimpleProof) ValidateBasic() error { + if sp.Total < 0 { + return errors.New("negative Total") + } + if sp.Index < 0 { + return errors.New("negative Index") + } + if len(sp.LeafHash) != tmhash.Size { + return errors.Errorf("expected LeafHash size to be %d, got %d", tmhash.Size, len(sp.LeafHash)) + } + if len(sp.Aunts) > maxAunts { + return errors.Errorf("expected no more than %d aunts, got %d", maxAunts, len(sp.Aunts)) + } + for i, auntHash := range sp.Aunts { + if len(auntHash) != tmhash.Size { + return errors.Errorf("expected Aunts#%d size to be %d, got %d", i, tmhash.Size, len(auntHash)) + } + } + return nil +} + // Use the leafHash and innerHashes to get the root merkle hash. // If the length of the innerHashes slice isn't exactly correct, the result is nil. // Recursive impl. diff --git a/crypto/merkle/simple_proof_test.go b/crypto/merkle/simple_proof_test.go new file mode 100644 index 00000000000..1a517905b5e --- /dev/null +++ b/crypto/merkle/simple_proof_test.go @@ -0,0 +1,38 @@ +package merkle + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleProofValidateBasic(t *testing.T) { + testCases := []struct { + testName string + malleateProof func(*SimpleProof) + errStr string + }{ + {"Good", func(sp *SimpleProof) {}, ""}, + {"Negative Total", func(sp *SimpleProof) { sp.Total = -1 }, "negative Total"}, + {"Negative Index", func(sp *SimpleProof) { sp.Index = -1 }, "negative Index"}, + {"Invalid LeafHash", func(sp *SimpleProof) { sp.LeafHash = make([]byte, 10) }, "expected LeafHash size to be 32, got 10"}, + {"Too many Aunts", func(sp *SimpleProof) { sp.Aunts = make([][]byte, maxAunts+1) }, "expected no more than 100 aunts, got 101"}, + {"Invalid Aunt", func(sp *SimpleProof) { sp.Aunts[0] = make([]byte, 10) }, "expected Aunts#0 size to be 32, got 10"}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + _, proofs := SimpleProofsFromByteSlices([][]byte{ + []byte("apple"), + []byte("watermelon"), + []byte("kiwi"), + }) + tc.malleateProof(proofs[0]) + err := proofs[0].ValidateBasic() + if tc.errStr != "" { + assert.Contains(t, err.Error(), tc.errStr) + } + }) + } +} diff --git a/types/params.go b/types/params.go index ec8a8f576bb..f49bf1be057 100644 --- a/types/params.go +++ b/types/params.go @@ -12,6 +12,9 @@ const ( // BlockPartSizeBytes is the size of one block part. BlockPartSizeBytes = 65536 // 64kB + + // MaxBlockPartsCount is the maximum count of block parts. + MaxBlockPartsCount = MaxBlockSizeBytes / BlockPartSizeBytes ) // ConsensusParams contains consensus critical parameters that determine the diff --git a/types/part_set.go b/types/part_set.go index a040258d1cb..4232e212a96 100644 --- a/types/part_set.go +++ b/types/part_set.go @@ -40,10 +40,13 @@ func (part *Part) Hash() []byte { // ValidateBasic performs basic validation. func (part *Part) ValidateBasic() error { if part.Index < 0 { - return errors.New("Negative Index") + return errors.New("negative Index") } if len(part.Bytes) > BlockPartSizeBytes { - return fmt.Errorf("Too big (max: %d)", BlockPartSizeBytes) + return errors.Errorf("too big: %d bytes, max: %d", len(part.Bytes), BlockPartSizeBytes) + } + if err := part.Proof.ValidateBasic(); err != nil { + return errors.Wrap(err, "wrong Proof") } return nil } diff --git a/types/part_set_test.go b/types/part_set_test.go index daa2fa5c5d5..5c0edaffd4b 100644 --- a/types/part_set_test.go +++ b/types/part_set_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/crypto/merkle" cmn "github.com/tendermint/tendermint/libs/common" ) @@ -114,6 +115,13 @@ func TestPartValidateBasic(t *testing.T) { {"Good Part", func(pt *Part) {}, false}, {"Negative index", func(pt *Part) { pt.Index = -1 }, true}, {"Too big part", func(pt *Part) { pt.Bytes = make([]byte, BlockPartSizeBytes+1) }, true}, + {"Too big proof", func(pt *Part) { + pt.Proof = merkle.SimpleProof{ + Total: 1, + Index: 1, + LeafHash: make([]byte, 1024*1024), + } + }, true}, } for _, tc := range testCases { diff --git a/types/vote_set.go b/types/vote_set.go index 0cf6cbb7f5d..fdaf45ec235 100644 --- a/types/vote_set.go +++ b/types/vote_set.go @@ -11,6 +11,12 @@ import ( cmn "github.com/tendermint/tendermint/libs/common" ) +const ( + // MaxVotesCount is the maximum votes count. Used in ValidateBasic funcs for + // protection against DOS attacks. + MaxVotesCount = 10000 +) + // UNSTABLE // XXX: duplicate of p2p.ID to avoid dependence between packages. // Perhaps we can have a minimal types package containing this (and other things?) diff --git a/version/version.go b/version/version.go index 98fda02d3a2..739700656ba 100644 --- a/version/version.go +++ b/version/version.go @@ -18,7 +18,7 @@ const ( // TMCoreSemVer is the current version of Tendermint Core. // It's the Semantic Version of the software. // Must be a string because scripts like dist.sh read this file. - TMCoreSemVer = "0.31.2" + TMCoreSemVer = "0.31.3" // ABCISemVer is the semantic version of the ABCI library ABCISemVer = "0.15.0"