diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 254b1f7fb4..ddf68bff4c 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -19,8 +19,11 @@ package abi import ( "bytes" "encoding/json" + "errors" "fmt" "io" + + "github.com/tomochain/tomochain/crypto" ) // The ABI holds information about a contract's context and available @@ -144,3 +147,26 @@ func (abi *ABI) MethodById(sigdata []byte) (*Method, error) { } return nil, fmt.Errorf("no method with id: %#x", sigdata[:4]) } + +// revertSelector is a special function selector for revert reason unpacking. +var revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] + +// UnpackRevert resolves the abi-encoded revert reason. According to the solidity +// spec https://solidity.readthedocs.io/en/latest/control-structures.html#revert, +// the provided revert reason is abi-encoded as if it were a call to a function +// `Error(string)`. So it's a special tool for it. +func UnpackRevert(data []byte) (string, error) { + if len(data) < 4 { + return "", errors.New("invalid data for unpacking") + } + if !bytes.Equal(data[:4], revertSelector) { + return "", errors.New("invalid data for unpacking") + } + var reason string + // typ, _ := NewType("string", "", nil) + typ, _ := NewType("string") + if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil { + return "", err + } + return reason, nil +} diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 5a128bfe54..9092b6cd81 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -19,6 +19,7 @@ package abi import ( "bytes" "encoding/hex" + "errors" "fmt" "log" "math/big" @@ -619,16 +620,19 @@ func TestBareEvents(t *testing.T) { } // TestUnpackEvent is based on this contract: -// contract T { -// event received(address sender, uint amount, bytes memo); -// event receivedAddr(address sender); -// function receive(bytes memo) external payable { -// received(msg.sender, msg.value, memo); -// receivedAddr(msg.sender); -// } -// } +// +// contract T { +// event received(address sender, uint amount, bytes memo); +// event receivedAddr(address sender); +// function receive(bytes memo) external payable { +// received(msg.sender, msg.value, memo); +// receivedAddr(msg.sender); +// } +// } +// // When receive("X") is called with sender 0x00... and value 1, it produces this tx receipt: -// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]} +// +// receipt{status=1 cgas=23949 bloom=00000000004000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000040200000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 logs=[log: b6818c8064f645cd82d99b59a1a267d6d61117ef [75fd880d39c1daf53b6547ab6cb59451fc6452d27caa90e5b6649dd8293b9eed] 000000000000000000000000376c47978271565f56deb45495afa69e59c16ab200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000000158 9ae378b6d4409eada347a5dc0c180f186cb62dc68fcc0f043425eb917335aa28 0 95d429d309bb9d753954195fe2d69bd140b4ae731b9b5b605c34323de162cf00 0]} func TestUnpackEvent(t *testing.T) { const abiJSON = `[{"constant":false,"inputs":[{"name":"memo","type":"bytes"}],"name":"receive","outputs":[],"payable":true,"stateMutability":"payable","type":"function"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"},{"indexed":false,"name":"amount","type":"uint256"},{"indexed":false,"name":"memo","type":"bytes"}],"name":"received","type":"event"},{"anonymous":false,"inputs":[{"indexed":false,"name":"sender","type":"address"}],"name":"receivedAddr","type":"event"}]` abi, err := JSON(strings.NewReader(abiJSON)) @@ -713,3 +717,34 @@ func TestABI_MethodById(t *testing.T) { } } + +func TestUnpackRevert(t *testing.T) { + t.Parallel() + + var cases = []struct { + input string + expect string + expectErr error + }{ + {"", "", errors.New("invalid data for unpacking")}, + {"08c379a1", "", errors.New("invalid data for unpacking")}, + {"08c379a00000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000d72657665727420726561736f6e00000000000000000000000000000000000000", "revert reason", nil}, + } + for index, c := range cases { + t.Run(fmt.Sprintf("case %d", index), func(t *testing.T) { + got, err := UnpackRevert(common.Hex2Bytes(c.input)) + if c.expectErr != nil { + if err == nil { + t.Fatalf("Expected non-nil error") + } + if err.Error() != c.expectErr.Error() { + t.Fatalf("Expected error mismatch, want %v, got %v", c.expectErr, err) + } + return + } + if c.expect != got { + t.Fatalf("Output mismatch, want %v, got %v", c.expect, got) + } + }) + } +} diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7411f492a8..d204c76d86 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -20,19 +20,21 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "time" "github.com/tomochain/tomochain" + "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -46,8 +48,11 @@ import ( // This nil assignment ensures compile time that SimulatedBackend implements bind.ContractBackend. var _ bind.ContractBackend = (*SimulatedBackend)(nil) -var errBlockNumberUnsupported = errors.New("SimulatedBackend cannot access blocks other than the latest block") -var errGasEstimationFailed = errors.New("gas required exceeds allowance or always failing transaction") +var ( + errBlockNumberUnsupported = errors.New("SimulatedBackend cannot access blocks other than the latest block") + errBlockDoesNotExist = errors.New("block does not exist in blockchain") + errGasEstimationFailed = errors.New("gas required exceeds allowance or always failing transaction") +) // SimulatedBackend implements bind.ContractBackend, simulating a blockchain in // the background. Its main purpose is to allow easily testing contract bindings. @@ -107,7 +112,7 @@ func (b *SimulatedBackend) rollback() { statedb, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil) } // CodeAt returns the code associated with a certain account in the blockchain. @@ -174,7 +179,7 @@ func (b *SimulatedBackend) ForEachStorageAt(ctx context.Context, contract common // TransactionReceipt returns the receipt of a transaction. func (b *SimulatedBackend) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { - receipt, _, _, _ := core.GetReceipt(b.database, txHash) + receipt, _, _, _ := rawdb.GetReceipt(b.database, txHash, b.config) return receipt, nil } @@ -186,6 +191,58 @@ func (b *SimulatedBackend) PendingCodeAt(ctx context.Context, contract common.Ad return b.pendingState.GetCode(contract), nil } +// BlockByHash retrieves a block based on the block hash. +func (b *SimulatedBackend) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) { + b.mu.Lock() + defer b.mu.Unlock() + + return b.blockByHash(ctx, hash) +} + +// blockByHash retrieves a block based on the block hash without Locking. +func (b *SimulatedBackend) blockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) { + if hash == b.pendingBlock.Hash() { + return b.pendingBlock, nil + } + + block := b.blockchain.GetBlockByHash(hash) + if block != nil { + return block, nil + } + + return nil, errBlockDoesNotExist +} + +func newRevertError(result *core.ExecutionResult) *revertError { + reason, errUnpack := abi.UnpackRevert(result.Revert()) + err := errors.New("execution reverted") + if errUnpack == nil { + err = fmt.Errorf("execution reverted: %v", reason) + } + return &revertError{ + error: err, + reason: hexutil.Encode(result.Revert()), + } +} + +// revertError is an API error that encompassas an EVM revertal with JSON error +// code and a binary data blob. +type revertError struct { + error + reason string // revert reason hex encoded +} + +// ErrorCode returns the JSON error code for a revertal. +// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal +func (e *revertError) ErrorCode() int { + return 3 +} + +// ErrorData returns the hex encoded revert reason. +func (e *revertError) ErrorData() interface{} { + return e.reason +} + // CallContract executes a contract call. func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.CallMsg, blockNumber *big.Int) ([]byte, error) { b.mu.Lock() @@ -198,11 +255,19 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call tomochain.Call if err != nil { return nil, err } - rval, _, _, err := b.callContract(ctx, call, b.blockchain.CurrentBlock(), state) - return rval, err + res, err := b.callContract(ctx, call, b.blockchain.CurrentBlock(), state) + if err != nil { + return nil, err + } + + if len(res.Revert()) > 0 { + return nil, newRevertError(res) + } + + return res.Return(), res.Err } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -215,11 +280,19 @@ func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain c call.Value = new(big.Int) } // Execute the call. - msg := callmsg{call} + msg := &core.Message{ + To: call.To, + From: call.From, + Value: call.Value, + GasLimit: call.Gas, + GasPrice: call.GasPrice, + Data: call.Data, + SkipAccountChecks: false, + } feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) - if msg.To() != nil { - if value, ok := feeCapacity[*msg.To()]; ok { - msg.CallMsg.BalanceTokenFee = value + if msg.To != nil { + if value, ok := feeCapacity[*msg.To]; ok { + msg.BalanceTokenFee = value } } evmContext := core.NewEVMContext(msg, chain.CurrentHeader(), chain, nil) @@ -228,11 +301,11 @@ func (b *SimulatedBackend) CallContractWithState(call tomochain.CallMsg, chain c vmenv := vm.NewEVM(evmContext, statedb, nil, chain.Config(), vm.Config{}) gaspool := new(core.GasPool).AddGas(1000000) owner := common.Address{} - rval, _, _, err := core.NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) + result, err := core.NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) if err != nil { return nil, err } - return rval, err + return result.Return(), nil } // PendingCallContract executes a contract call on the pending state. @@ -241,8 +314,15 @@ func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call tomocha defer b.mu.Unlock() defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - rval, _, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - return rval, err + res, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) + if err != nil { + return nil, err + } + if len(res.Revert()) > 0 { + return nil, newRevertError(res) + } + + return res.Return(), res.Err } // PendingNonceAt implements PendingStateReader.PendingNonceAt, retrieving @@ -280,23 +360,32 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM cap = hi // Create a helper to check if a gas allowance results in an executable transaction - executable := func(gas uint64) bool { + executable := func(gas uint64) (bool, *core.ExecutionResult, error) { call.Gas = gas snapshot := b.pendingState.Snapshot() - _, _, failed, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) - fmt.Println("EstimateGas",err,failed) + res, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) b.pendingState.RevertToSnapshot(snapshot) - if err != nil || failed { - return false + if err != nil { + if err == core.ErrIntrinsicGas { + return true, nil, nil // Special case, raise gas limit + } + return true, nil, err } - return true + return res.Failed(), res, nil } // Execute the binary search and hone in on an executable gas limit for lo+1 < hi { mid := (hi + lo) / 2 - if !executable(mid) { + failed, _, err := executable(mid) + // If the error is not nil(consensus error), it means the provided message + // call or transaction will never be accepted no matter how much gas it is + // assigned. Return the error directly, don't struggle any more + if err != nil { + return 0, err + } + if failed { lo = mid } else { hi = mid @@ -304,8 +393,21 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM } // Reject the transaction as invalid if it still fails at the highest allowance if hi == cap { - if !executable(hi) { - return 0, errGasEstimationFailed + failed, result, err := executable(hi) + if err != nil { + return 0, err + } + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + + if len(result.Revert()) > 0 { + return 0, newRevertError(result) + } + return 0, result.Err + } + + // Otherwise, the specified gas cap is too low + return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap) } } return hi, nil @@ -313,7 +415,7 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call tomochain.CallM // callContract implements common code between normal and pending contract calls. // state is modified during execution, make sure to copy it if necessary. -func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.CallMsg, block *types.Block, statedb *state.StateDB) ([]byte, uint64, bool, error) { +func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.CallMsg, block *types.Block, statedb *state.StateDB) (*core.ExecutionResult, error) { // Ensure message is initialized properly. if call.GasPrice == nil { call.GasPrice = big.NewInt(1) @@ -328,11 +430,19 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call tomochain.Call from := statedb.GetOrNewStateObject(call.From) from.SetBalance(math.MaxBig256) // Execute the call. - msg := callmsg{call} + msg := &core.Message{ + To: call.To, + From: call.From, + Value: call.Value, + GasLimit: call.Gas, + GasPrice: call.GasPrice, + Data: call.Data, + SkipAccountChecks: true, + } feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) - if msg.To() != nil { - if value, ok := feeCapacity[*msg.To()]; ok { - msg.CallMsg.BalanceTokenFee = value + if msg.To != nil { + if value, ok := feeCapacity[*msg.To]; ok { + msg.BalanceTokenFee = value } } evmContext := core.NewEVMContext(msg, block.Header(), b.blockchain, nil) @@ -350,7 +460,14 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa b.mu.Lock() defer b.mu.Unlock() - sender, err := types.Sender(types.HomesteadSigner{}, tx) + // Get the last block + block, err := b.blockByHash(ctx, b.pendingBlock.ParentHash()) + if err != nil { + return errors.New("could not fetch parent") + } + // Check transaction validity + signer := types.MakeSigner(b.blockchain.Config(), block.Number()) + sender, err := types.Sender(signer, tx) if err != nil { panic(fmt.Errorf("invalid transaction: %v", err)) } @@ -368,7 +485,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa statedb, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil) return nil } @@ -447,26 +564,11 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { statedb, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database(), nil) return nil } -// callmsg implements core.Message to allow passing it as a transaction simulator. -type callmsg struct { - tomochain.CallMsg -} - -func (m callmsg) From() common.Address { return m.CallMsg.From } -func (m callmsg) Nonce() uint64 { return 0 } -func (m callmsg) CheckNonce() bool { return false } -func (m callmsg) To() *common.Address { return m.CallMsg.To } -func (m callmsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callmsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callmsg) Value() *big.Int { return m.CallMsg.Value } -func (m callmsg) Data() []byte { return m.CallMsg.Data } -func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } - // filterBackend implements filters.Backend to support filtering for logs without // taking bloom-bits acceleration structures into account. type filterBackend struct { @@ -485,11 +587,11 @@ func (fb *filterBackend) HeaderByNumber(ctx context.Context, block rpc.BlockNumb } func (fb *filterBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { - return core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)), nil + return rawdb.GetBlockReceipts(fb.db, hash, rawdb.GetBlockNumber(fb.db, hash), fb.bc.Config()), nil } func (fb *filterBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*types.Log, error) { - receipts := core.GetBlockReceipts(fb.db, hash, core.GetBlockNumber(fb.db, hash)) + receipts := rawdb.GetBlockReceipts(fb.db, hash, rawdb.GetBlockNumber(fb.db, hash), fb.bc.Config()) if receipts == nil { return nil, nil } diff --git a/accounts/keystore/keystore_wallet.go b/accounts/keystore/keystore_wallet.go index 01ffd75a8e..91ac138786 100644 --- a/accounts/keystore/keystore_wallet.go +++ b/accounts/keystore/keystore_wallet.go @@ -90,7 +90,7 @@ func (w *keystoreWallet) SignHash(account accounts.Account, hash []byte) ([]byte if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignHash(account, hash) } @@ -106,7 +106,7 @@ func (w *keystoreWallet) SignTx(account accounts.Account, tx *types.Transaction, if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignTx(account, tx, chainID) } @@ -120,7 +120,7 @@ func (w *keystoreWallet) SignHashWithPassphrase(account accounts.Account, passph if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignHashWithPassphrase(account, passphrase, hash) } @@ -134,6 +134,6 @@ func (w *keystoreWallet) SignTxWithPassphrase(account accounts.Account, passphra if account.URL != (accounts.URL{}) && account.URL != w.account.URL { return nil, accounts.ErrUnknownAccount } - // Account seems valid, request the keystore to sign + // StateAccount seems valid, request the keystore to sign return w.keystore.SignTxWithPassphrase(account, passphrase, tx, chainID) } diff --git a/accounts/usbwallet/wallet.go b/accounts/usbwallet/wallet.go index d3cda1f21e..2cb2ca2ae7 100644 --- a/accounts/usbwallet/wallet.go +++ b/accounts/usbwallet/wallet.go @@ -319,7 +319,7 @@ func (w *wallet) selfDerive() { // Termination requested continue case reqc = <-w.deriveReq: - // Account discovery requested + // StateAccount discovery requested } // Derivation needs a chain and device access, skip if either unavailable w.stateLock.RLock() diff --git a/build/ci.go b/build/ci.go index ea44817049..6af2b18afe 100644 --- a/build/ci.go +++ b/build/ci.go @@ -14,6 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . +//go:build none // +build none /* @@ -23,14 +24,13 @@ Usage: go run build/ci.go Available commands are: - install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables - test [ -coverage ] [ packages... ] -- runs the tests - lint -- runs certain pre-selected linters - importkeys -- imports signing keys from env - xgo [ -alltools ] [ options ] -- cross builds according to options + install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables + test [ -coverage ] [ packages... ] -- runs the tests + lint -- runs certain pre-selected linters + importkeys -- imports signing keys from env + xgo [ -alltools ] [ options ] -- cross builds according to options For all commands, -n prevents execution of external programs (dry run mode). - */ package main @@ -62,6 +62,7 @@ var ( executablePath("rlpdump"), executablePath("swarm"), executablePath("wnode"), + executablePath("rlp/rlpgen"), } ) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 5d3b242898..75abb768ce 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -20,24 +20,25 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "os" + goruntime "runtime" "runtime/pprof" "time" - goruntime "runtime" + cli "gopkg.in/urfave/cli.v1" "github.com/tomochain/tomochain/cmd/evm/internal/compiler" "github.com/tomochain/tomochain/cmd/utils" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/core/vm/runtime" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" - cli "gopkg.in/urfave/cli.v1" + "github.com/tomochain/tomochain/trie" ) var runCommand = cli.Command{ @@ -83,6 +84,7 @@ func runCmd(ctx *cli.Context) error { debugLogger *vm.StructLogger statedb *state.StateDB chainConfig *params.ChainConfig + preimages = ctx.Bool(DumpFlag.Name) sender = common.StringToAddress("sender") receiver = common.StringToAddress("receiver") ) @@ -98,11 +100,11 @@ func runCmd(ctx *cli.Context) error { gen := readGenesis(ctx.GlobalString(GenesisFlag.Name)) db := rawdb.NewMemoryDatabase() genesis := gen.ToBlock(db) - statedb, _ = state.New(genesis.Root(), state.NewDatabase(db)) + statedb, _ = state.New(genesis.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages}), nil) chainConfig = gen.Config } else { db := rawdb.NewMemoryDatabase() - statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ = state.New(common.Hash{}, state.NewDatabaseWithConfig(db, &trie.Config{Preimages: preimages}), nil) } if ctx.GlobalString(SenderFlag.Name) != "" { sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name)) diff --git a/cmd/evm/staterunner.go b/cmd/evm/staterunner.go index 5499be6962..018a7c5262 100644 --- a/cmd/evm/staterunner.go +++ b/cmd/evm/staterunner.go @@ -94,7 +94,7 @@ func stateTestCmd(ctx *cli.Context) error { for _, st := range test.Subtests() { // Run the test and aggregate the result result := &StatetestResult{Name: key, Fork: st.Fork, Pass: true} - state, err := test.Run(st, cfg) + state, err := test.Run(st, cfg, false) if err != nil { // Test failed, mark as so and dump any state to aid debugging result.Pass, result.Error = false, err.Error() diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go index 6014f3c5a2..45a5e6cb4f 100644 --- a/cmd/faucet/faucet.go +++ b/cmd/faucet/faucet.go @@ -200,7 +200,7 @@ type faucet struct { index []byte // Index page to serve up on the web keystore *keystore.KeyStore // Keystore containing the single signer - account accounts.Account // Account funding user faucet requests + account accounts.Account // StateAccount funding user faucet requests nonce uint64 // Current pending nonce of the faucet price *big.Int // Current gas price to issue funds with diff --git a/cmd/gc/main.go b/cmd/gc/main.go index 567349ee42..8b3552dca9 100644 --- a/cmd/gc/main.go +++ b/cmd/gc/main.go @@ -3,9 +3,6 @@ package main import ( "flag" "fmt" - "github.com/tomochain/tomochain/core/rawdb" - "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/ethdb/leveldb" "os" "os/signal" "runtime" @@ -13,13 +10,14 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/cmd/utils" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" - "github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/eth" - "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/ethdb/leveldb" "github.com/tomochain/tomochain/trie" ) @@ -54,15 +52,15 @@ func main() { flag.Parse() db, _ := leveldb.New(*dir, eth.DefaultConfig.DatabaseCache, utils.MakeDatabaseHandles(), "") lddb := rawdb.NewDatabase(db) - head := core.GetHeadBlockHash(lddb) - currentHeader := core.GetHeader(lddb, head, core.GetBlockNumber(lddb, head)) + head := rawdb.GetHeadBlockHash(lddb) + currentHeader := rawdb.GetHeader(lddb, head, rawdb.GetBlockNumber(lddb, head)) tridb := trie.NewDatabase(lddb) catchEventInterupt(db) cache, _ = lru.New(*cacheSize) go func() { for i := uint64(1); i <= currentHeader.Number.Uint64(); i++ { - hash := core.GetCanonicalHash(lddb, i) - root := core.GetHeader(lddb, hash, i).Root + hash := rawdb.GetCanonicalHash(lddb, i) + root := rawdb.GetHeader(lddb, hash, i).Root trieRoot, err := trie.NewSecure(root, tridb) if err != nil { continue @@ -81,9 +79,7 @@ func main() { atomic.StoreInt32(&finish, 1) if running { for _, address := range cleanAddress { - enc := trieRoot.trie.Get(address.Bytes()) - var data state.Account - rlp.DecodeBytes(enc, &data) + data, _ := trieRoot.trie.GetAccount(address) fmt.Println(time.Now().Format(time.RFC3339), "Start clean state address ", address.Hex(), " at block ", trieRoot.number) signerRoot, err := resolveHash(data.Root[:], db) if err != nil { diff --git a/cmd/puppeth/genesis.go b/cmd/puppeth/genesis.go index ebca4a082e..90dbbe8fcd 100644 --- a/cmd/puppeth/genesis.go +++ b/cmd/puppeth/genesis.go @@ -39,6 +39,7 @@ type cppEthereumGenesisSpec struct { EIP158ForkBlock hexutil.Uint64 `json:"EIP158ForkBlock"` ByzantiumForkBlock hexutil.Uint64 `json:"byzantiumForkBlock"` ConstantinopleForkBlock hexutil.Uint64 `json:"constantinopleForkBlock"` + EIP2718ForkBlock hexutil.Uint64 `json:"eip2718ForkBlock"` NetworkID hexutil.Uint64 `json:"networkID"` ChainID hexutil.Uint64 `json:"chainID"` MaximumExtraDataSize hexutil.Uint64 `json:"maximumExtraDataSize"` @@ -102,6 +103,7 @@ func newCppEthereumGenesisSpec(network string, genesis *core.Genesis) (*cppEther spec.Params.EIP158ForkBlock = (hexutil.Uint64)(genesis.Config.EIP158Block.Uint64()) spec.Params.ByzantiumForkBlock = (hexutil.Uint64)(genesis.Config.ByzantiumBlock.Uint64()) spec.Params.ConstantinopleForkBlock = (hexutil.Uint64)(math.MaxUint64) + spec.Params.EIP2718ForkBlock = (hexutil.Uint64)(genesis.Config.EIP2718Block.Uint64()) spec.Params.NetworkID = (hexutil.Uint64)(genesis.Config.ChainId.Uint64()) spec.Params.ChainID = (hexutil.Uint64)(genesis.Config.ChainId.Uint64()) diff --git a/cmd/puppeth/wizard_genesis.go b/cmd/puppeth/wizard_genesis.go index 1d278662c6..aefe32d788 100644 --- a/cmd/puppeth/wizard_genesis.go +++ b/cmd/puppeth/wizard_genesis.go @@ -18,27 +18,25 @@ package main import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" + "math/big" "math/rand" "time" - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" - "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/params" - - "context" - "math/big" - "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/accounts/abi/bind/backends" + "github.com/tomochain/tomochain/common" blockSignerContract "github.com/tomochain/tomochain/contracts/blocksigner" multiSignWalletContract "github.com/tomochain/tomochain/contracts/multisigwallet" randomizeContract "github.com/tomochain/tomochain/contracts/randomize" validatorContract "github.com/tomochain/tomochain/contracts/validator" + "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -56,6 +54,7 @@ func (w *wizard) makeGenesis() { EIP155Block: big.NewInt(3), EIP158Block: big.NewInt(3), ByzantiumBlock: big.NewInt(4), + EIP2718Block: big.NewInt(1), }, } // Figure out which consensus engine to choose @@ -79,6 +78,12 @@ func (w *wizard) makeGenesis() { Period: 15, Epoch: 30000, } + + // Query the user for some custom extras + fmt.Println() + fmt.Println("Specify your chain/network ID if you want an explicit one (default = random)") + genesis.Config.ChainId = new(big.Int).SetUint64(uint64(w.readDefaultInt(rand.Intn(65536)))) + fmt.Println() fmt.Println("How many seconds should blocks take? (default = 15)") genesis.Config.Clique.Period = uint64(w.readDefaultInt(15)) @@ -117,6 +122,8 @@ func (w *wizard) makeGenesis() { Epoch: 30000, Reward: 0, } + genesis.Config.ChainId = params.AllEthashProtocolChanges.ChainId + fmt.Println() fmt.Println("How many seconds should blocks take? (default = 2)") genesis.Config.Posv.Period = uint64(w.readDefaultInt(2)) @@ -177,7 +184,7 @@ func (w *wizard) makeGenesis() { // Validator Smart Contract Code pKey, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") addr := crypto.PubkeyToAddress(pKey.PublicKey) - contractBackend := backends.NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(1000000000)}}) + contractBackend := backends.NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(1_000_000_000_000_000_000)}}) transactOpts := bind.NewKeyedTransactor(pKey) validatorAddress, _, err := validatorContract.DeployValidator(transactOpts, contractBackend, signers, validatorCaps, owner) @@ -386,6 +393,10 @@ func (w *wizard) manageGenesis() { fmt.Printf("Which block should Byzantium come into effect? (default = %v)\n", w.conf.Genesis.Config.ByzantiumBlock) w.conf.Genesis.Config.ByzantiumBlock = w.readDefaultBigInt(w.conf.Genesis.Config.ByzantiumBlock) + fmt.Println() + fmt.Printf("Which block should EIP-2718 come into effect? (default = %v)\n", w.conf.Genesis.Config.EIP2718Block) + w.conf.Genesis.Config.EIP2718Block = w.readDefaultBigInt(w.conf.Genesis.Config.EIP2718Block) + out, _ := json.MarshalIndent(w.conf.Genesis.Config, "", " ") fmt.Printf("Chain configuration updated:\n\n%s\n", out) diff --git a/cmd/tomo/bugcmd.go b/cmd/tomo/bugcmd.go index 3174f73881..5cec10ad49 100644 --- a/cmd/tomo/bugcmd.go +++ b/cmd/tomo/bugcmd.go @@ -105,5 +105,4 @@ const header = `Please answer these questions before submitting your issue. Than #### What did you see instead? -#### System details -` +#### System details` diff --git a/cmd/tomo/chaincmd.go b/cmd/tomo/chaincmd.go index dc0a274ba5..e1c23c0cdf 100644 --- a/cmd/tomo/chaincmd.go +++ b/cmd/tomo/chaincmd.go @@ -66,6 +66,7 @@ It expects the genesis file as argument.`, utils.CacheFlag, utils.LightModeFlag, utils.GCModeFlag, + utils.SnapshotFlag, utils.CacheDatabaseFlag, utils.CacheGCFlag, }, @@ -450,7 +451,7 @@ func dump(ctx *cli.Context) error { fmt.Println("{}") utils.Fatalf("block not found") } else { - state, err := state.New(block.Root(), state.NewDatabase(chainDb)) + state, err := state.New(block.Root(), state.NewDatabase(chainDb), nil) if err != nil { utils.Fatalf("could not create new state: %v", err) } diff --git a/cmd/tomo/consolecmd_test.go b/cmd/tomo/consolecmd_test.go index 241373f521..894f55c698 100644 --- a/cmd/tomo/consolecmd_test.go +++ b/cmd/tomo/consolecmd_test.go @@ -52,7 +52,7 @@ func TestConsoleWelcome(t *testing.T) { tomo.SetTemplateFunc("goarch", func() string { return runtime.GOARCH }) tomo.SetTemplateFunc("gover", runtime.Version) tomo.SetTemplateFunc("tomover", func() string { return params.Version }) - tomo.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format(time.RFC1123) }) + tomo.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)") }) tomo.SetTemplateFunc("apis", func() string { return ipcAPIs }) // Verify the actual welcome message to the required template @@ -137,7 +137,7 @@ func testAttachWelcome(t *testing.T, tomo *testtomo, endpoint, apis string) { attach.SetTemplateFunc("gover", runtime.Version) attach.SetTemplateFunc("tomover", func() string { return params.Version }) attach.SetTemplateFunc("etherbase", func() string { return tomo.Etherbase }) - attach.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format(time.RFC1123) }) + attach.SetTemplateFunc("niltime", func() string { return time.Unix(1544771829, 0).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)") }) attach.SetTemplateFunc("ipc", func() bool { return strings.HasPrefix(endpoint, "ipc") }) attach.SetTemplateFunc("datadir", func() string { return tomo.Datadir }) attach.SetTemplateFunc("apis", func() string { return apis }) diff --git a/cmd/tomo/dao_test.go b/cmd/tomo/dao_test.go index 773f1ed152..768a7bb762 100644 --- a/cmd/tomo/dao_test.go +++ b/cmd/tomo/dao_test.go @@ -17,7 +17,6 @@ package main import ( - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -25,7 +24,7 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" ) // Genesis block for nodes which don't care about the DAO fork (i.e. not configured) @@ -130,7 +129,7 @@ func testDAOForkBlockNewChain(t *testing.T, test int, genesis string, expectBloc if genesis != "" { genesisHash = daoGenesisHash } - config, err := core.GetChainConfig(db, genesisHash) + config, err := rawdb.GetChainConfig(db, genesisHash) if err != nil { t.Errorf("test %d: failed to retrieve chain config: %v", test, err) return // we want to return here, the other checks can't make it past this point (nil panic). diff --git a/cmd/tomo/main.go b/cmd/tomo/main.go index 2a606fbb78..1b08a78ced 100644 --- a/cmd/tomo/main.go +++ b/cmd/tomo/main.go @@ -86,6 +86,7 @@ var ( utils.LightModeFlag, utils.SyncModeFlag, utils.GCModeFlag, + utils.SnapshotFlag, //utils.LightServFlag, //utils.LightPeersFlag, //utils.LightKDFFlag, @@ -93,6 +94,7 @@ var ( //utils.CacheDatabaseFlag, //utils.CacheGCFlag, //utils.TrieCacheGenFlag, + utils.CacheSnapshotFlag, utils.ListenPortFlag, utils.MaxPeersFlag, utils.MaxPendingPeersFlag, diff --git a/cmd/tomo/usage.go b/cmd/tomo/usage.go index f166d9aae9..8840af8394 100644 --- a/cmd/tomo/usage.go +++ b/cmd/tomo/usage.go @@ -123,15 +123,15 @@ var AppHelpFlagGroups = []flagGroup{ // utils.TxPoolLifetimeFlag, // }, //}, - //{ - // Name: "PERFORMANCE TUNING", - // Flags: []cli.Flag{ - // utils.CacheFlag, - // utils.CacheDatabaseFlag, - // utils.CacheGCFlag, - // utils.TrieCacheGenFlag, - // }, - //}, + { + Name: "PERFORMANCE TUNING", + Flags: []cli.Flag{ + utils.CacheFlag, + utils.CacheDatabaseFlag, + utils.CacheGCFlag, + utils.CacheSnapshotFlag, + }, + }, { Name: "ACCOUNT", Flags: []cli.Flag{ diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index a3787f7311..667098e90e 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -29,6 +29,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" @@ -271,7 +272,7 @@ func ImportPreimages(db ethdb.Database, fn string) error { // Accumulate the preimages and flush when enough ws gathered preimages[crypto.Keccak256Hash(blob)] = common.CopyBytes(blob) if len(preimages) > 1024 { - if err := core.WritePreimages(db, 0, preimages); err != nil { + if err := rawdb.WritePreimages(db, 0, preimages); err != nil { return err } preimages = make(map[common.Hash][]byte) @@ -279,7 +280,7 @@ func ImportPreimages(db ethdb.Database, fn string) error { } // Flush the last batch preimage data if len(preimages) > 0 { - return core.WritePreimages(db, 0, preimages) + return rawdb.WritePreimages(db, 0, preimages) } return nil } diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 59a0cdeaf0..90e85528b2 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -190,6 +190,10 @@ var ( Usage: `Blockchain garbage collection mode ("full", "archive")`, Value: "full", } + SnapshotFlag = cli.BoolFlag{ + Name: "snapshot", + Usage: `Enables snapshot-database mode -- experimental work in progress feature`, + } LightServFlag = cli.IntFlag{ Name: "lightserv", Usage: "Maximum percentage of time allowed for serving LES requests (0-90)", @@ -305,6 +309,11 @@ var ( Usage: "Percentage of cache memory allowance to use for trie pruning", Value: 25, } + CacheSnapshotFlag = cli.IntFlag{ + Name: "cache.snapshot", + Usage: "Percentage of cache memory allowance to use for snapshot caching (default = 10% full mode, 20% archive mode)", + Value: 10, + } // Miner settings StakingEnabledFlag = cli.BoolFlag{ Name: "mine", diff --git a/common/bytes.go b/common/bytes.go index ba00e8a4b2..1801cb1cae 100644 --- a/common/bytes.go +++ b/common/bytes.go @@ -119,3 +119,25 @@ func LeftPadBytes(slice []byte, l int) []byte { return padded } + +// TrimLeftZeroes returns a subslice of s without leading zeroes +func TrimLeftZeroes(s []byte) []byte { + idx := 0 + for ; idx < len(s); idx++ { + if s[idx] != 0 { + break + } + } + return s[idx:] +} + +// TrimRightZeroes returns a subslice of s without trailing zeroes +func TrimRightZeroes(s []byte) []byte { + idx := len(s) + for ; idx > 0; idx-- { + if s[idx-1] != 0 { + break + } + } + return s[:idx] +} diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index f63373e17e..5c03e332cc 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -40,6 +40,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/trie" ) const ( @@ -575,7 +576,7 @@ func (c *Clique) Finalize(chain consensus.ChainReader, header *types.Header, sta header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks diff --git a/consensus/clique/snapshot.go b/consensus/clique/snapshot.go index 3c2bf703d8..9a1e9e8846 100644 --- a/consensus/clique/snapshot.go +++ b/consensus/clique/snapshot.go @@ -32,7 +32,7 @@ import ( type Vote struct { Signer common.Address `json:"signer"` // Authorized signer that cast this vote Block uint64 `json:"block"` // Block number the vote was cast in (expire old votes) - Address common.Address `json:"address"` // Account being voted on to change its authorization + Address common.Address `json:"address"` // StateAccount being voted on to change its authorization Authorize bool `json:"authorize"` // Whether to authorize or deauthorize the voted account } diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index 12f63cfde7..7064569927 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -32,6 +32,7 @@ import ( "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // Ethash proof-of-work protocol constants. @@ -519,7 +520,7 @@ func (ethash *Ethash) Finalize(chain consensus.ChainReader, header *types.Header header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number)) // Header seems complete, assemble into a block and return - return types.NewBlock(header, txs, uncles, receipts), nil + return types.NewBlock(header, txs, uncles, receipts, new(trie.StackTrie)), nil } // Some weird constants to avoid constant memory allocs for them. diff --git a/consensus/posv/posv.go b/consensus/posv/posv.go index 0027104970..71f9e52f70 100644 --- a/consensus/posv/posv.go +++ b/consensus/posv/posv.go @@ -21,9 +21,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" "io/ioutil" "math/big" "math/rand" @@ -35,6 +32,7 @@ import ( "time" lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" @@ -50,6 +48,10 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -66,6 +68,7 @@ type Masternode struct { type TradingService interface { GetTradingStateRoot(block *types.Block, author common.Address) (common.Hash, error) GetTradingState(block *types.Block, author common.Address) (*tradingstate.TradingStateDB, error) + GetEmptyTradingState() (*tradingstate.TradingStateDB, error) HasTradingState(block *types.Block, author common.Address) bool GetStateCache() tradingstate.Database GetTriegc() *prque.Prque @@ -181,7 +184,7 @@ var ( // SignerFn is a signer callback function to request a hash to be signed by a // backing account. -//type SignerFn func(accounts.Account, []byte) ([]byte, error) +//type SignerFn func(accounts.StateAccount, []byte) ([]byte, error) // sigHash returns the hash which is used as input for the proof-of-stake-voting // signing. It is the hash of the entire header apart from the 65 byte signature @@ -985,7 +988,7 @@ func (c *Posv) Finalize(chain consensus.ChainReader, header *types.Header, state header.UncleHash = types.CalcUncleHash(nil) // Assemble and return the final block for sealing - return types.NewBlock(header, txs, nil, receipts), nil + return types.NewBlock(header, txs, nil, receipts, new(trie.StackTrie)), nil } // Authorize injects a private key into the consensus engine to mint new blocks @@ -1146,7 +1149,7 @@ func (c *Posv) CacheData(header *types.Header, txs []*types.Transaction, receipt signTxs := []*types.Transaction{} for _, tx := range txs { if tx.IsSigningTransaction() { - var b uint + var b uint64 for _, r := range receipts { if r.TxHash == tx.Hash() { if len(r.PostState) > 0 { diff --git a/consensus/posv/snapshot.go b/consensus/posv/snapshot.go index aef9e2a39f..01f9d50e42 100644 --- a/consensus/posv/snapshot.go +++ b/consensus/posv/snapshot.go @@ -32,7 +32,7 @@ import ( //type Vote struct { // Signer common.Address `json:"signer"` // Authorized signer that cast this vote // Block uint64 `json:"block"` // Block number the vote was cast in (expire old votes) -// Address common.Address `json:"address"` // Account being voted on to change its authorization +// Address common.Address `json:"address"` // StateAccount being voted on to change its authorization // Authorize bool `json:"authorize"` // Whether to authorize or deauthorize the voted account //} diff --git a/console/console_test.go b/console/console_test.go index 22527f4ddc..98f85c4b43 100644 --- a/console/console_test.go +++ b/console/console_test.go @@ -19,8 +19,6 @@ package console import ( "bytes" "errors" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "os" "strings" @@ -29,10 +27,13 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/console/prompt" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/internal/jsre" "github.com/tomochain/tomochain/node" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomoxlending" ) const ( @@ -67,10 +68,10 @@ func (p *hookedPrompter) PromptPassword(prompt string) (string, error) { func (p *hookedPrompter) PromptConfirm(prompt string) (bool, error) { return false, errors.New("not implemented") } -func (p *hookedPrompter) SetHistory(history []string) {} -func (p *hookedPrompter) AppendHistory(command string) {} -func (p *hookedPrompter) ClearHistory() {} -func (p *hookedPrompter) SetWordCompleter(completer WordCompleter) {} +func (p *hookedPrompter) SetHistory(history []string) {} +func (p *hookedPrompter) AppendHistory(command string) {} +func (p *hookedPrompter) ClearHistory() {} +func (p *hookedPrompter) SetWordCompleter(completer prompt.WordCompleter) {} // tester is a console test environment for the console tests to operate on. type tester struct { @@ -262,7 +263,7 @@ func TestPrettyError(t *testing.T) { defer tester.Close(t) tester.console.Evaluate("throw 'hello'") - want := jsre.ErrorColor("hello") + "\n" + want := jsre.ErrorColor("hello") + "\n\tat :1:1(1)\n\n" if output := tester.output.String(); output != want { t.Fatalf("pretty error mismatch: have %s, want %s", output, want) } diff --git a/contracts/utils.go b/contracts/utils.go index 4468b5de9a..eede7c43f6 100644 --- a/contracts/utils.go +++ b/contracts/utils.go @@ -39,6 +39,7 @@ import ( "github.com/tomochain/tomochain/contracts/blocksigner/contract" randomizeContract "github.com/tomochain/tomochain/contracts/randomize/contract" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" stateDatabase "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" @@ -336,7 +337,7 @@ func GetRewardForCheckpoint(c *posv.Posv, chain consensus.ChainReader, header *t block := chain.GetBlock(header.Hash(), i) txs := block.Transactions() if !chain.Config().IsTIPSigning(header.Number) { - receipts := core.GetBlockReceipts(c.GetDb(), header.Hash(), i) + receipts := rawdb.GetBlockReceipts(c.GetDb(), header.Hash(), i, chain.Config()) signData = c.CacheData(header, txs, receipts) } else { signData = c.CacheSigner(header.Hash(), txs) diff --git a/contracts/validator/validator_test.go b/contracts/validator/validator_test.go index c7a452d751..9cdb8bec87 100644 --- a/contracts/validator/validator_test.go +++ b/contracts/validator/validator_test.go @@ -60,10 +60,7 @@ func TestValidator(t *testing.T) { d := time.Now().Add(1000 * time.Millisecond) ctx, cancel := context.WithDeadline(context.Background(), d) defer cancel() - code, _ := contractBackend.CodeAt(ctx, validatorAddress, nil) - t.Log("contract code", common.ToHex(code)) f := func(key, val common.Hash) bool { - t.Log(key.Hex(), val.Hex()) return true } contractBackend.ForEachStorageAt(ctx, validatorAddress, nil, f) diff --git a/core/bench_test.go b/core/bench_test.go index 137b57f031..5380398f91 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -18,7 +18,6 @@ package core import ( "crypto/ecdsa" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -235,13 +235,13 @@ func makeChainForBench(db ethdb.Database, full bool, count uint64) { ReceiptHash: types.EmptyRootHash, } hash = header.Hash() - WriteHeader(db, header) - WriteCanonicalHash(db, hash, n) - WriteTd(db, hash, n, big.NewInt(int64(n+1))) + rawdb.WriteHeader(db, header) + rawdb.WriteCanonicalHash(db, hash, n) + rawdb.WriteTd(db, hash, n, big.NewInt(int64(n+1))) if full || n == 0 { block := types.NewBlockWithHeader(header) - WriteBody(db, hash, n, block.Body()) - WriteBlockReceipts(db, hash, n, nil) + rawdb.WriteBody(db, hash, n, block.Body()) + rawdb.WriteBlockReceipts(db, hash, n, nil) } } } @@ -275,6 +275,8 @@ func benchReadChain(b *testing.B, full bool, count uint64) { } makeChainForBench(db, full, count) db.Close() + cacheConfig := defaultCacheConfig + cacheConfig.Disabled = true b.ReportAllocs() b.ResetTimer() @@ -284,7 +286,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) { if err != nil { b.Fatalf("error opening database at %v: %v", dir, err) } - chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) + chain, err := NewBlockChain(db, cacheConfig, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) if err != nil { b.Fatalf("error creating chain: %v", err) } @@ -293,8 +295,8 @@ func benchReadChain(b *testing.B, full bool, count uint64) { header := chain.GetHeaderByNumber(n) if full { hash := header.Hash() - GetBody(db, hash, n) - GetBlockReceipts(db, hash, n) + rawdb.GetBody(db, hash, n) + rawdb.GetBlockReceipts(db, hash, n, params.TestChainConfig) } } diff --git a/core/block_validator.go b/core/block_validator.go index 34fde4cedd..63e3f54383 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -18,6 +18,7 @@ package core import ( "fmt" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" @@ -27,6 +28,7 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" ) // BlockValidator is responsible for validating block headers, uncles and @@ -71,7 +73,7 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error { if hash := types.CalcUncleHash(block.Uncles()); hash != header.UncleHash { return fmt.Errorf("uncle root hash mismatch: have %x, want %x", hash, header.UncleHash) } - if hash := types.DeriveSha(block.Transactions()); hash != header.TxHash { + if hash := types.DeriveSha(block.Transactions(), new(trie.StackTrie)); hash != header.TxHash { return fmt.Errorf("transaction root hash mismatch: have %x, want %x", hash, header.TxHash) } return nil @@ -93,7 +95,7 @@ func (v *BlockValidator) ValidateState(block, parent *types.Block, statedb *stat return fmt.Errorf("invalid bloom (remote: %x local: %x)", header.Bloom, rbloom) } // Tre receipt Trie's root (R = (Tr [[H1, R1], ... [Hn, R1]])) - receiptSha := types.DeriveSha(receipts) + receiptSha := types.DeriveSha(receipts, new(trie.StackTrie)) if receiptSha != header.ReceiptHash { return fmt.Errorf("invalid receipt root hash (remote: %x local: %x)", header.ReceiptHash, receiptSha) } diff --git a/core/blockchain.go b/core/blockchain.go index f763189be7..2a912a4ce4 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -28,18 +28,18 @@ import ( "sync/atomic" "time" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" + lru "github.com/hashicorp/golang-lru" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" "github.com/tomochain/tomochain/accounts/abi/bind" - "github.com/tomochain/tomochain/tomox/tradingstate" - - lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/posv" contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -50,14 +50,40 @@ import ( "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" "github.com/tomochain/tomochain/trie" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( - blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) - CheckpointCh = make(chan int) - ErrNoGenesis = errors.New("Genesis not found in chain") + accountReadTimer = metrics.NewRegisteredTimer("chain/account/reads", nil) + accountHashTimer = metrics.NewRegisteredTimer("chain/account/hashes", nil) + accountUpdateTimer = metrics.NewRegisteredTimer("chain/account/updates", nil) + accountCommitTimer = metrics.NewRegisteredTimer("chain/account/commits", nil) + + storageReadTimer = metrics.NewRegisteredTimer("chain/storage/reads", nil) + storageHashTimer = metrics.NewRegisteredTimer("chain/storage/hashes", nil) + storageUpdateTimer = metrics.NewRegisteredTimer("chain/storage/updates", nil) + storageCommitTimer = metrics.NewRegisteredTimer("chain/storage/commits", nil) + + snapshotAccountReadTimer = metrics.NewRegisteredTimer("chain/snapshot/account/reads", nil) + snapshotStorageReadTimer = metrics.NewRegisteredTimer("chain/snapshot/storage/reads", nil) + snapshotCommitTimer = metrics.NewRegisteredTimer("chain/snapshot/commits", nil) + + blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) + blockValidationTimer = metrics.NewRegisteredTimer("chain/validation", nil) + blockExecutionTimer = metrics.NewRegisteredTimer("chain/execution", nil) + blockWriteTimer = metrics.NewRegisteredTimer("chain/write", nil) + blockReorgAddMeter = metrics.NewRegisteredMeter("chain/reorg/drop", nil) + blockReorgDropMeter = metrics.NewRegisteredMeter("chain/reorg/add", nil) + + blockPrefetchExecuteTimer = metrics.NewRegisteredTimer("chain/prefetch/executes", nil) + blockPrefetchInterruptMeter = metrics.NewRegisteredMeter("chain/prefetch/interrupts", nil) + + errInsertionInterrupted = errors.New("insertion is interrupted") + + CheckpointCh = make(chan int) + ErrNoGenesis = errors.New("Genesis not found in chain") ) const ( @@ -81,7 +107,18 @@ type CacheConfig struct { Disabled bool // Whether to disable trie write caching (archive node) TrieNodeLimit int // Memory limit (MB) at which to flush the current in-memory trie to disk TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk + SnapshotLimit int // Memory allowance (MB) to use for caching snapshot entries in memory + + SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it +} + +// defaultCacheConfig are the default caching values if none are specified by the +// user (also used during testing). +var defaultCacheConfig = &CacheConfig{ + TrieNodeLimit: 256, + TrieTimeLimit: 5 * time.Minute, } + type ResultProcessBlock struct { logs []*types.Log receipts []*types.Receipt @@ -112,8 +149,9 @@ type BlockChain struct { db ethdb.Database // Low level persistent database to store final content in tomoxDb ethdb.TomoxDatabase - triegc *prque.Prque // Priority queue mapping block numbers to tries to gc - gcproc time.Duration // Accumulates canonical block processing for trie dumping + snaps *snapshot.Tree // Snapshot tree for fast trie leaf access + triegc *prque.Prque // Priority queue mapping block numbers to tries to gc + gcproc time.Duration // Accumulates canonical block processing for trie dumping hc *HeaderChain rmLogsFeed event.Feed @@ -175,6 +213,8 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par cacheConfig = &CacheConfig{ TrieNodeLimit: 256 * 1024 * 1024, TrieTimeLimit: 5 * time.Minute, + SnapshotLimit: 256, + SnapshotWait: true, } } bodyCache, _ := lru.New(bodyCacheLimit) @@ -247,6 +287,10 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par } } } + // Load any existing snapshot, regenerating it if loading failed + if bc.cacheConfig.SnapshotLimit > 0 { + bc.snaps = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, bc.CurrentBlock().Root(), !bc.cacheConfig.SnapshotWait) + } // Take ownership of this particular state go bc.update() return bc, nil @@ -276,7 +320,7 @@ func (bc *BlockChain) addTomoxDb(tomoxDb ethdb.TomoxDatabase) { // assumes that the chain manager mutex is held. func (bc *BlockChain) loadLastState() error { // Restore the last known head block - head := GetHeadBlockHash(bc.db) + head := rawdb.GetHeadBlockHash(bc.db) if head == (common.Hash{}) { // Corrupt or empty database, init from scratch log.Warn("Empty database, resetting chain") @@ -289,13 +333,25 @@ func (bc *BlockChain) loadLastState() error { log.Warn("Head block missing, resetting chain", "hash", head) return bc.Reset() } + // Make sure the state associated with the block is available + if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil { + // Dangling block without a state associated, init from scratch + log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) + if err := bc.repair(¤tBlock); err != nil { + return err + } + rawdb.WriteHeadBlockHash(bc.db, currentBlock.Hash()) + } + + // Everything seems to be fine, set as the head block + bc.currentBlock.Store(currentBlock) + repair := false if common.Rewound != uint64(0) { repair = true } // Make sure the state associated with the block is available - _, err := state.New(currentBlock.Root(), bc.stateCache) - if err != nil { + if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil { repair = true } else { engine, ok := bc.Engine().(*posv.Posv) @@ -344,7 +400,7 @@ func (bc *BlockChain) loadLastState() error { // Restore the last known head header currentHeader := currentBlock.Header() - if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) { + if head := rawdb.GetHeadHeaderHash(bc.db); head != (common.Hash{}) { if header := bc.GetHeaderByHash(head); header != nil { currentHeader = header } @@ -353,7 +409,7 @@ func (bc *BlockChain) loadLastState() error { // Restore the last known head fast block bc.currentFastBlock.Store(currentBlock) - if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) { + if head := rawdb.GetHeadFastBlockHash(bc.db); head != (common.Hash{}) { if block := bc.GetBlockByHash(head); block != nil { bc.currentFastBlock.Store(block) } @@ -385,7 +441,7 @@ func (bc *BlockChain) SetHead(head uint64) error { // Rewind the header chain, deleting all block bodies until then delFn := func(hash common.Hash, num uint64) { - DeleteBody(bc.db, hash, num) + rawdb.DeleteBody(bc.db, hash, num) } bc.hc.SetHead(head, delFn) currentHeader := bc.hc.CurrentHeader() @@ -402,7 +458,7 @@ func (bc *BlockChain) SetHead(head uint64) error { bc.currentBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())) } if currentBlock := bc.CurrentBlock(); currentBlock != nil { - if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil { + if _, err := state.New(currentBlock.Root(), bc.stateCache, bc.snaps); err != nil { // Rewound state missing, rolled back to before pivot, reset to genesis bc.currentBlock.Store(bc.genesisBlock) } @@ -420,10 +476,10 @@ func (bc *BlockChain) SetHead(head uint64) error { } currentBlock := bc.CurrentBlock() currentFastBlock := bc.CurrentFastBlock() - if err := WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil { log.Crit("Failed to reset head full block", "err", err) } - if err := WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil { + if err := rawdb.WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil { log.Crit("Failed to reset head fast block", "err", err) } return bc.loadLastState() @@ -445,6 +501,11 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error { bc.currentBlock.Store(block) bc.mu.Unlock() + // Destroy any existing state snapshot and regenerate it in the background + if bc.snaps != nil { + log.Info("Destroy any existing state snapshot and regenerate it in the background", "Snapshot", bc.snaps) + bc.snaps.Rebuild(block.Root()) + } log.Info("Committed new head block", "number", block.Number(), "hash", hash) return nil } @@ -501,7 +562,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) { // StateAt returns a new mutable state based on a particular point in time. func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { - return state.New(root, bc.stateCache) + return state.New(root, bc.stateCache, bc.snaps) } // OrderStateAt returns a new mutable state based on a particular point in time. @@ -518,6 +579,13 @@ func (bc *BlockChain) OrderStateAt(block *types.Block) (*tradingstate.TradingSta } else { return nil, err } + } else { + tomoxState, err := tomoXService.GetEmptyTradingState() + if err == nil { + return tomoxState, nil + } else { + return nil, err + } } } return nil, errors.New("Get tomox state fail") @@ -562,7 +630,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error { if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { log.Crit("Failed to write genesis block TD", "err", err) } - if err := WriteBlock(bc.db, genesis); err != nil { + if err := rawdb.WriteBlock(bc.db, genesis); err != nil { log.Crit("Failed to write genesis block", "err", err) } bc.genesisBlock = genesis @@ -585,7 +653,7 @@ func (bc *BlockChain) repair(head **types.Block) error { for { // Abort if we've rewound to a head block that does have associated state if (common.Rewound == uint64(0)) || ((*head).Number().Uint64() < common.Rewound) { - if _, err := state.New((*head).Root(), bc.stateCache); err == nil { + if _, err := state.New((*head).Root(), bc.stateCache, bc.snaps); err == nil { log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash()) engine, ok := bc.Engine().(*posv.Posv) if ok { @@ -658,13 +726,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error { // Note, this function assumes that the `mu` mutex is held! func (bc *BlockChain) insert(block *types.Block) { // If the block is on a side chain or an unknown one, force other heads onto it too - updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash() + updateHeads := rawdb.GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash() // Add the block to the canonical chain number scheme and mark as the head - if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil { + if err := rawdb.WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil { log.Crit("Failed to insert block number", "err", err) } - if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head block hash", "err", err) } bc.currentBlock.Store(block) @@ -681,7 +749,7 @@ func (bc *BlockChain) insert(block *types.Block) { if updateHeads { bc.hc.SetCurrentHeader(block.Header()) - if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil { + if err := rawdb.WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head fast block hash", "err", err) } bc.currentFastBlock.Store(block) @@ -701,7 +769,7 @@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body { body := cached.(*types.Body) return body } - body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash)) + body := rawdb.GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash)) if body == nil { return nil } @@ -717,7 +785,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue { if cached, ok := bc.bodyRLPCache.Get(hash); ok { return cached.(rlp.RawValue) } - body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash)) + body := rawdb.GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash)) if len(body) == 0 { return nil } @@ -731,7 +799,7 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool { if bc.blockCache.Contains(hash) { return true } - ok, _ := bc.db.Has(blockBodyKey(hash, number)) + ok, _ := bc.db.Has(rawdb.BlockBodyKey(number, hash)) return ok } @@ -774,7 +842,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { if block, ok := bc.blockCache.Get(hash); ok { return block.(*types.Block) } - block := GetBlock(bc.db, hash, number) + block := rawdb.GetBlock(bc.db, hash, number) if block == nil { return nil } @@ -791,7 +859,7 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block { // GetBlockByNumber retrieves a block from the database by number, caching it // (associated with its hash) if found. func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { - hash := GetCanonicalHash(bc.db, number) + hash := rawdb.GetCanonicalHash(bc.db, number) if hash == (common.Hash{}) { return nil } @@ -800,7 +868,7 @@ func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { // GetReceiptsByHash retrieves the receipts for all transactions in a given block. func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { - return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash)) + return rawdb.GetBlockReceipts(bc.db, hash, rawdb.GetBlockNumber(bc.db, hash), bc.chainConfig) } // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors. @@ -867,18 +935,28 @@ func (bc *BlockChain) SaveData() { // Make sure no inconsistent state is leaked during insertion bc.mu.Lock() defer bc.mu.Unlock() + // Ensure that the entirety of the state snapshot is journalled to disk. + var snapBase common.Hash + if bc.snaps != nil { + var err error + if snapBase, err = bc.snaps.Journal(bc.CurrentBlock().Root()); err != nil { + log.Error("Failed to journal state snapshot", "err", err) + } + } // Ensure the state of a recent block is also stored to disk before exiting. // We're writing three different states to catch different restart scenarios: // - HEAD: So we don't need to reprocess any blocks in the general case // - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle // - HEAD-127: So we have a hard limit on the number of blocks reexecuted if !bc.cacheConfig.Disabled { - var tradingTriedb *trie.Database - var lendingTriedb *trie.Database + var ( + tradingTriedb *trie.Database + lendingTriedb *trie.Database + tradingService posv.TradingService + lendingService posv.LendingService + ) engine, _ := bc.Engine().(*posv.Posv) triedb := bc.stateCache.TrieDB() - var tradingService posv.TradingService - var lendingService posv.LendingService if bc.Config().IsTIPTomoX(bc.CurrentBlock().Number()) && bc.chainConfig.Posv != nil && bc.CurrentBlock().NumberU64() > bc.chainConfig.Posv.Epoch && engine != nil { tradingService = engine.GetTomoXService() if tradingService != nil && tradingService.GetStateCache() != nil { @@ -918,6 +996,12 @@ func (bc *BlockChain) SaveData() { } } } + if snapBase != (common.Hash{}) { + log.Info("Writing snapshot state to disk", "root", snapBase) + if err := triedb.Commit(snapBase, true); err != nil { + log.Error("Failed to commit recent state trie", "err", err) + } + } for !bc.triegc.Empty() { triedb.Dereference(bc.triegc.PopItem().(common.Hash)) } @@ -996,12 +1080,12 @@ func (bc *BlockChain) Rollback(chain []common.Hash) { if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock.Hash() == hash { newFastBlock := bc.GetBlock(currentFastBlock.ParentHash(), currentFastBlock.NumberU64()-1) bc.currentFastBlock.Store(newFastBlock) - WriteHeadFastBlockHash(bc.db, newFastBlock.Hash()) + rawdb.WriteHeadFastBlockHash(bc.db, newFastBlock.Hash()) } if currentBlock := bc.CurrentBlock(); currentBlock.Hash() == hash { newBlock := bc.GetBlock(currentBlock.ParentHash(), currentBlock.NumberU64()-1) bc.currentBlock.Store(newBlock) - WriteHeadBlockHash(bc.db, newBlock.Hash()) + rawdb.WriteHeadBlockHash(bc.db, newBlock.Hash()) } } } @@ -1086,13 +1170,13 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ return i, fmt.Errorf("failed to set receipts data: %v", err) } // Write all the data out into the database - if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil { + if err := rawdb.WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil { return i, fmt.Errorf("failed to write block body: %v", err) } - if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { + if err := rawdb.WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { return i, fmt.Errorf("failed to write block receipts: %v", err) } - if err := WriteTxLookupEntries(batch, block); err != nil { + if err := rawdb.WriteTxLookupEntries(batch, block); err != nil { return i, fmt.Errorf("failed to write lookup metadata: %v", err) } stats.processed++ @@ -1118,7 +1202,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case currentFastBlock := bc.CurrentFastBlock() if bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64()).Cmp(td) < 0 { - if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil { + if err := rawdb.WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil { log.Crit("Failed to update head fast block hash", "err", err) } bc.currentFastBlock.Store(head) @@ -1148,7 +1232,7 @@ func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (e if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil { return err } - if err := WriteBlock(bc.db, block); err != nil { + if err := rawdb.WriteBlock(bc.db, block); err != nil { return err } return nil @@ -1178,7 +1262,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. } // Write other block data using a batch. batch := bc.db.NewBatch() - if err := WriteBlock(batch, block); err != nil { + if err := rawdb.WriteBlock(batch, block); err != nil { return NonStatTy, err } root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number())) @@ -1324,7 +1408,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. } } } - if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { + if err := rawdb.WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { return NonStatTy, err } // If the total difficulty is higher than our known, add it to the canonical chain @@ -1344,11 +1428,11 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. } } // Write the positional metadata for transaction and receipt lookups - if err := WriteTxLookupEntries(batch, block); err != nil { + if err := rawdb.WriteTxLookupEntries(batch, block); err != nil { return NonStatTy, err } // Write hash preimages - if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil { + if err := rawdb.WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil { return NonStatTy, err } status = CanonStatTy @@ -1521,7 +1605,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } else { parent = chain[i-1] } - statedb, err := state.New(parent.Root(), bc.stateCache) + statedb, err := state.New(parent.Root(), bc.stateCache, bc.snaps) if err != nil { return i, events, coalescedLogs, err } @@ -1532,11 +1616,13 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } parentAuthor, _ := bc.Engine().Author(parent.Header()) // clear the previous dry-run cache - var tradingState *tradingstate.TradingStateDB - var lendingState *lendingstate.LendingStateDB - var tradingService posv.TradingService - var lendingService posv.LendingService - isSDKNode := false + var ( + tradingState *tradingstate.TradingStateDB + lendingState *lendingstate.LendingStateDB + tradingService posv.TradingService + lendingService posv.LendingService + isSDKNode = false + ) if bc.Config().IsTIPTomoX(block.Number()) && bc.chainConfig.Posv != nil && engine != nil && block.NumberU64() > bc.chainConfig.Posv.Epoch { tradingService = engine.GetTomoXService() lendingService = engine.GetLendingService() @@ -1627,6 +1713,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } feeCapacity := state.GetTRC21FeeCapacityFromStateWithCache(parent.Root(), statedb) // Process block using the parent state as reference point. + substart := time.Now() receipts, logs, usedGas, err := bc.processor.Process(block, statedb, tradingState, bc.vmConfig, feeCapacity) if err != nil { bc.reportBlock(block, receipts, err) @@ -1638,12 +1725,34 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty bc.reportBlock(block, receipts, err) return i, events, coalescedLogs, err } + // Update the metrics touched during block processing + accountReadTimer.Update(statedb.AccountReads) // Account reads are complete, we can mark them + storageReadTimer.Update(statedb.StorageReads) // Storage reads are complete, we can mark them + accountUpdateTimer.Update(statedb.AccountUpdates) // Account updates are complete, we can mark them + storageUpdateTimer.Update(statedb.StorageUpdates) // Storage updates are complete, we can mark them + snapshotAccountReadTimer.Update(statedb.SnapshotAccountReads) // Account reads are complete, we can mark them + snapshotStorageReadTimer.Update(statedb.SnapshotStorageReads) // Storage reads are complete, we can mark them + + triehash := statedb.AccountHashes + statedb.StorageHashes // Save to not double count in validation + trieproc := statedb.SnapshotAccountReads + statedb.AccountReads + statedb.AccountUpdates + trieproc += statedb.SnapshotStorageReads + statedb.StorageReads + statedb.StorageUpdates + + blockExecutionTimer.Update(time.Since(substart) - trieproc - triehash) + proctime := time.Since(bstart) // Write the block to the chain and get the status. status, err := bc.WriteBlockWithState(block, receipts, statedb, tradingState, lendingState) if err != nil { return i, events, coalescedLogs, err } + + // Update the metrics touched during block commit + accountCommitTimer.Update(statedb.AccountCommits) // Account commits are complete, we can mark them + storageCommitTimer.Update(statedb.StorageCommits) // Storage commits are complete, we can mark them + snapshotCommitTimer.Update(statedb.SnapshotCommits) // Snapshot commits are complete, we can mark them + + blockWriteTimer.Update(time.Since(substart) - statedb.AccountCommits - statedb.StorageCommits - statedb.SnapshotCommits) + if bc.chainConfig.Posv != nil { c := bc.engine.(*posv.Posv) coinbase := c.Signer() @@ -1813,7 +1922,7 @@ func (bc *BlockChain) getResultBlock(block *types.Block, verifiedM2 bool) (*Resu // Create a new statedb using the parent block and report an // error if it fails. var parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1) - statedb, err := state.New(parent.Root(), bc.stateCache) + statedb, err := state.New(parent.Root(), bc.stateCache, bc.snaps) if err != nil { return nil, err } @@ -2120,7 +2229,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // These logs are later announced as deleted. collectLogs = func(h common.Hash) { // Coalesce logs and set 'Removed'. - receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h)) + receipts := rawdb.GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h), bc.chainConfig) for _, receipt := range receipts { for _, log := range receipt.Logs { del := *log @@ -2189,7 +2298,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // insert the block in the canonical way, re-writing history bc.insert(newChain[i]) // write lookup entries for hash based transaction/receipt searches - if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil { + if err := rawdb.WriteTxLookupEntries(bc.db, newChain[i]); err != nil { return err } addedTxs = append(addedTxs, newChain[i].Transactions()...) @@ -2199,7 +2308,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // When transactions get deleted from the database that means the // receipts that were created in the fork must also be deleted for _, tx := range diff { - DeleteTxLookupEntry(bc.db, tx.Hash()) + rawdb.DeleteTxLookupEntry(bc.db, tx.Hash()) } if len(deletedLogs) > 0 { go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs}) diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 6860924112..a59b522c2d 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -17,8 +17,8 @@ package core import ( + "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "sync" @@ -27,11 +27,13 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) // Test fork of length N starting from block i @@ -113,7 +115,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { } return err } - statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache) + statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache, nil) if err != nil { return err } @@ -128,8 +130,8 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { return err } blockchain.mu.Lock() - WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) - WriteBlock(blockchain.db, block) + rawdb.WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) + rawdb.WriteBlock(blockchain.db, block) statedb.Commit(true) blockchain.mu.Unlock() } @@ -146,8 +148,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error } // Manually insert the header into the database, but don't reorganise (allows subsequent testing) blockchain.mu.Lock() - WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) - WriteHeader(blockchain.db, header) + rawdb.WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) + rawdb.WriteHeader(blockchain.db, header) blockchain.mu.Unlock() } return nil @@ -173,7 +175,7 @@ func TestLastBlock(t *testing.T) { if _, err := blockchain.InsertChain(blocks); err != nil { t.Fatalf("Failed to insert block: %v", err) } - if blocks[len(blocks)-1].Hash() != GetHeadBlockHash(blockchain.db) { + if blocks[len(blocks)-1].Hash() != rawdb.GetHeadBlockHash(blockchain.db) { t.Fatalf("Write/Get HeadBlockHash failed") } } @@ -617,18 +619,18 @@ func TestFastVsFullChains(t *testing.T) { } if fblock, ablock := fast.GetBlockByHash(hash), archive.GetBlockByHash(hash); fblock.Hash() != ablock.Hash() { t.Errorf("block #%d [%x]: block mismatch: have %v, want %v", num, hash, fblock, ablock) - } else if types.DeriveSha(fblock.Transactions()) != types.DeriveSha(ablock.Transactions()) { + } else if types.DeriveSha(fblock.Transactions(), new(trie.StackTrie)) != types.DeriveSha(ablock.Transactions(), new(trie.StackTrie)) { t.Errorf("block #%d [%x]: transactions mismatch: have %v, want %v", num, hash, fblock.Transactions(), ablock.Transactions()) } else if types.CalcUncleHash(fblock.Uncles()) != types.CalcUncleHash(ablock.Uncles()) { t.Errorf("block #%d [%x]: uncles mismatch: have %v, want %v", num, hash, fblock.Uncles(), ablock.Uncles()) } - if freceipts, areceipts := GetBlockReceipts(fastDb, hash, GetBlockNumber(fastDb, hash)), GetBlockReceipts(archiveDb, hash, GetBlockNumber(archiveDb, hash)); types.DeriveSha(freceipts) != types.DeriveSha(areceipts) { + if freceipts, areceipts := rawdb.GetBlockReceipts(fastDb, hash, rawdb.GetBlockNumber(fastDb, hash), fast.Config()), rawdb.GetBlockReceipts(archiveDb, hash, rawdb.GetBlockNumber(archiveDb, hash), fast.Config()); types.DeriveSha(freceipts, trie.NewStackTrie(nil)) != types.DeriveSha(areceipts, trie.NewStackTrie(nil)) { t.Errorf("block #%d [%x]: receipts mismatch: have %v, want %v", num, hash, freceipts, areceipts) } } // Check that the canonical chains are the same between the databases for i := 0; i < len(blocks)+1; i++ { - if fhash, ahash := GetCanonicalHash(fastDb, uint64(i)), GetCanonicalHash(archiveDb, uint64(i)); fhash != ahash { + if fhash, ahash := rawdb.GetCanonicalHash(fastDb, uint64(i)), rawdb.GetCanonicalHash(archiveDb, uint64(i)); fhash != ahash { t.Errorf("block #%d: canonical hash mismatch: have %v, want %v", i, fhash, ahash) } } @@ -804,28 +806,28 @@ func TestChainTxReorgs(t *testing.T) { // removed tx for i, tx := range (types.Transactions{pastDrop, freshDrop}) { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn != nil { + if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn != nil { t.Errorf("drop %d: tx %v found while shouldn't have been", i, txn) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt != nil { + if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt != nil { t.Errorf("drop %d: receipt %v found while shouldn't have been", i, rcpt) } } // added tx for i, tx := range (types.Transactions{pastAdd, freshAdd, futureAdd}) { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil { + if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn == nil { t.Errorf("add %d: expected tx to be found", i) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil { + if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt == nil { t.Errorf("add %d: expected receipt to be found", i) } } // shared tx for i, tx := range (types.Transactions{postponed, swapped}) { - if txn, _, _, _ := GetTransaction(db, tx.Hash()); txn == nil { + if txn, _, _, _ := rawdb.GetTransaction(db, tx.Hash()); txn == nil { t.Errorf("share %d: expected tx to be found", i) } - if rcpt, _, _, _ := GetReceipt(db, tx.Hash()); rcpt == nil { + if rcpt, _, _, _ := rawdb.GetReceipt(db, tx.Hash(), blockchain.Config()); rcpt == nil { t.Errorf("share %d: expected receipt to be found", i) } } @@ -980,14 +982,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) { // try to retrieve a block by its canonical hash and see if the block data can be retrieved. for { - ch := GetCanonicalHash(blockchain.db, block.NumberU64()) + ch := rawdb.GetCanonicalHash(blockchain.db, block.NumberU64()) if ch == (common.Hash{}) { continue // busy wait for canonical hash to be written } if ch != block.Hash() { t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex()) } - fb := GetBlock(blockchain.db, ch, block.NumberU64()) + fb := rawdb.GetBlock(blockchain.db, ch, block.NumberU64()) if fb == nil { t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex()) } @@ -1005,6 +1007,72 @@ func TestCanonicalBlockRetrieval(t *testing.T) { pend.Wait() } +// TestEIP2718Transition tests that an EIP-2718 transaction will be accepted +// after the fork block has passed. This is verified by sending an EIP-2817 +// paymaster transaction and then checking that the gas usage of a hot SLOAD +// and a cold SLOAD are calculated correctly. +func TestEIP2718Transition(t *testing.T) { + var ( + aa = common.HexToAddress("0x000000000000000000000000000000000000aaaa") + db = rawdb.NewMemoryDatabase() + + // A sender who makes transactions, has some funds + key, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + address = crypto.PubkeyToAddress(key.PublicKey) + funds = big.NewInt(1000000000000000) + gspec = &Genesis{ + Config: params.TestChainConfig, + Alloc: GenesisAlloc{ + address: {Balance: funds}, + // The address 0xAAAA sloads 0x00 and 0x01 + aa: { + Code: []byte{ + byte(vm.PC), + byte(vm.PC), + byte(vm.SLOAD), + byte(vm.SLOAD), + }, + Nonce: 0, + Balance: big.NewInt(0), + }, + }, + } + genesis = gspec.MustCommit(db) + ) + // Generate blocks + blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, block *BlockGen) { + block.SetCoinbase(common.Address{1}) + + // One transaction to 0xAAAA + signer := types.LatestSigner(gspec.Config) + tx, _ := types.SignNewTx(key, signer, &types.PaymasterTx{ + ChainID: gspec.Config.ChainId, + Nonce: 0, + To: &aa, + Gas: 30000, + GasPrice: new(big.Int), + PmPayload: address.Bytes(), + }) + block.AddTx(tx) + }) + + // Import the canonical chain + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) + defer blockchain.Stop() + + if n, err := blockchain.InsertChain(blocks); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", n, err) + } + + block := blockchain.GetBlockByNumber(1) + + // Expected gas is intrinsic + 2 * pc + 2 * SLOAD + expected := params.TxGas + vm.GasQuickStep*2 + params.SloadGasEIP150*2 + if block.GasUsed() != expected { + t.Fatalf("incorrect amount of gas spent: expected %d, got %d", expected, block.GasUsed()) + } +} + func TestEIP155Transition(t *testing.T) { // Configure and generate a sample block chain var ( @@ -1104,8 +1172,8 @@ func TestEIP155Transition(t *testing.T) { } }) _, err := blockchain.InsertChain(blocks) - if err != types.ErrInvalidChainId { - t.Error("expected error:", types.ErrInvalidChainId) + if have, want := err, types.ErrInvalidChainId; !errors.Is(have, want) { + t.Errorf("have %v, want %v", have, want) } } diff --git a/core/chain_indexer.go b/core/chain_indexer.go index 95190eea93..41f3919904 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -24,6 +24,7 @@ import ( "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -206,7 +207,7 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE // TODO(karalabe): This operation is expensive and might block, causing the event system to // potentially also lock up. We need to do with on a different thread somehow. - if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { + if h := rawdb.FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { c.newHead(h.Number.Uint64(), true) } } @@ -349,11 +350,11 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com } for number := section * c.sectionSize; number < (section+1)*c.sectionSize; number++ { - hash := GetCanonicalHash(c.chainDb, number) + hash := rawdb.GetCanonicalHash(c.chainDb, number) if hash == (common.Hash{}) { return common.Hash{}, fmt.Errorf("canonical block #%d unknown", number) } - header := GetHeader(c.chainDb, hash, number) + header := rawdb.GetHeader(c.chainDb, hash, number) if header == nil { return common.Hash{}, fmt.Errorf("block #%d [%x…] not found", number, hash[:4]) } else if header.ParentHash != lastHead { diff --git a/core/chain_indexer_test.go b/core/chain_indexer_test.go index a954c062d9..3a50819b9d 100644 --- a/core/chain_indexer_test.go +++ b/core/chain_indexer_test.go @@ -18,13 +18,13 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "testing" "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" ) @@ -92,10 +92,10 @@ func testChainIndexer(t *testing.T, count int) { inject := func(number uint64) { header := &types.Header{Number: big.NewInt(int64(number)), Extra: big.NewInt(rand.Int63()).Bytes()} if number > 0 { - header.ParentHash = GetCanonicalHash(db, number-1) + header.ParentHash = rawdb.GetCanonicalHash(db, number-1) } - WriteHeader(db, header) - WriteCanonicalHash(db, header.Hash(), number) + rawdb.WriteHeader(db, header) + rawdb.WriteCanonicalHash(db, header.Hash(), number) } // Start indexer with an already existing chain for i := uint64(0); i <= 100; i++ { diff --git a/core/chain_makers.go b/core/chain_makers.go index ac7c311fd2..1e2aeb88b0 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -18,12 +18,12 @@ package core import ( "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/misc" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -115,6 +115,15 @@ func (b *BlockGen) AddTxWithChain(bc *BlockChain, tx *types.Transaction) { } } +// AddUncheckedTx forcefully adds a transaction to the block without any +// validation. +// +// AddUncheckedTx will cause consensus failures when used during real +// chain processing. This is best used in conjunction with raw block insertion. +func (b *BlockGen) AddUncheckedTx(tx *types.Transaction) { + b.txs = append(b.txs, tx) +} + // Number returns the block number of the block being generated. func (b *BlockGen) Number() *big.Int { return new(big.Int).Set(b.header.Number) @@ -225,7 +234,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse return nil, nil } for i := 0; i < n; i++ { - statedb, err := state.New(parent.Root(), state.NewDatabase(db)) + statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil) if err != nil { panic(err) } diff --git a/core/error.go b/core/error.go index 63be6ab83d..14177d13fe 100644 --- a/core/error.go +++ b/core/error.go @@ -33,9 +33,23 @@ var ( // next one expected based on the local chain. ErrNonceTooHigh = errors.New("nonce too high") + // ErrNonceMax is returned if the nonce of a transaction sender account has + // maximum allowed value and would become invalid if incremented. + ErrNonceMax = errors.New("nonce has max value") + ErrNotPoSV = errors.New("Posv not found in config") ErrNotFoundM1 = errors.New("list M1 not found ") ErrStopPreparingBlock = errors.New("stop calculating a block not verified by M2") + + // ErrSenderNoEOA is returned if the sender of a transaction is a contract. + ErrSenderNoEOA = errors.New("sender not an eoa") + + // ErrGasUintOverflow is returned when calculating gas usage. + ErrGasUintOverflow = errors.New("gas uint64 overflow") + + // ErrInsufficientFundsForTransfer is returned if the transaction sender doesn't + // have enough funds for transfer(topmost call only). + ErrInsufficientFundsForTransfer = errors.New("insufficient funds for transfer") ) diff --git a/core/evm.go b/core/evm.go index 04636999b3..f3ac62a73f 100644 --- a/core/evm.go +++ b/core/evm.go @@ -26,7 +26,7 @@ import ( ) // NewEVMContext creates a new context for use in the EVM. -func NewEVMContext(msg Message, header *types.Header, chain consensus.ChainContext, author *common.Address) vm.Context { +func NewEVMContext(msg *Message, header *types.Header, chain consensus.ChainContext, author *common.Address) vm.Context { // If we don't have an explicit author (i.e. not mining), extract from the header var beneficiary common.Address if author == nil { @@ -38,13 +38,13 @@ func NewEVMContext(msg Message, header *types.Header, chain consensus.ChainConte CanTransfer: CanTransfer, Transfer: Transfer, GetHash: GetHashFn(header, chain), - Origin: msg.From(), + Origin: msg.From, Coinbase: beneficiary, BlockNumber: new(big.Int).Set(header.Number), Time: new(big.Int).Set(header.Time), Difficulty: new(big.Int).Set(header.Difficulty), GasLimit: header.GasLimit, - GasPrice: new(big.Int).Set(msg.GasPrice()), + GasPrice: new(big.Int).Set(msg.GasPrice), } } diff --git a/core/genesis.go b/core/genesis.go index e1b7185a41..77970085fd 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -22,10 +22,11 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" @@ -35,6 +36,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" ) //go:generate gencodec -type Genesis -field-override genesisSpecMarshaling -out gen_genesis.go @@ -140,10 +142,10 @@ func (e *GenesisMismatchError) Error() string { // SetupGenesisBlock writes or updates the genesis block in db. // The block that will be used is: // -// genesis == nil genesis != nil -// +------------------------------------------ -// db has no genesis | main-net default | genesis -// db has genesis | from DB | genesis (if compatible) +// genesis == nil genesis != nil +// +------------------------------------------ +// db has no genesis | main-net default | genesis +// db has genesis | from DB | genesis (if compatible) // // The stored chain configuration will be updated if it is compatible (i.e. does not // specify a fork block below the local head block). In case of a conflict, the @@ -156,7 +158,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig } // Just commit the new block if there is no stored genesis block. - stored := GetCanonicalHash(db, 0) + stored := rawdb.GetCanonicalHash(db, 0) if (stored == common.Hash{}) { if genesis == nil { log.Info("Writing default main-net genesis block") @@ -178,12 +180,12 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Get the existing chain configuration. newcfg := genesis.configOrDefault(stored) - storedcfg, err := GetChainConfig(db, stored) + storedcfg, err := rawdb.GetChainConfig(db, stored) if err != nil { - if err == ErrChainConfigNotFound { + if err == rawdb.ErrChainConfigNotFound { // This case happens if a genesis write was interrupted. log.Warn("Found genesis block without chain config") - err = WriteChainConfig(db, stored, newcfg) + err = rawdb.WriteChainConfig(db, stored, newcfg) } return newcfg, stored, err } @@ -196,15 +198,15 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Check config compatibility and write the config. Compatibility errors // are returned to the caller unless we're already at block zero. - height := GetBlockNumber(db, GetHeadHeaderHash(db)) - if height == missingNumber { + height := rawdb.GetBlockNumber(db, rawdb.GetHeadHeaderHash(db)) + if height == rawdb.MissingNumber { return newcfg, stored, fmt.Errorf("missing block number for head header hash") } compatErr := storedcfg.CheckCompatible(newcfg, height) if compatErr != nil && height != 0 && compatErr.RewindTo != 0 { return newcfg, stored, compatErr } - return newcfg, stored, WriteChainConfig(db, stored, newcfg) + return newcfg, stored, rawdb.WriteChainConfig(db, stored, newcfg) } func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { @@ -226,7 +228,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { if db == nil { db = rawdb.NewMemoryDatabase() } - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) for addr, account := range g.Alloc { statedb.AddBalance(addr, account.Balance) statedb.SetCode(addr, account.Code) @@ -258,7 +260,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { statedb.Commit(false) statedb.Database().TrieDB().Commit(root, true) - return types.NewBlock(head, nil, nil, nil) + return types.NewBlock(head, nil, nil, nil, new(trie.StackTrie)) } // Commit writes the block and state of a genesis specification to the database. @@ -268,29 +270,29 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { if block.Number().Sign() != 0 { return nil, fmt.Errorf("can't commit genesis block with number > 0") } - if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { + if err := rawdb.WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { return nil, err } - if err := WriteBlock(db, block); err != nil { + if err := rawdb.WriteBlock(db, block); err != nil { return nil, err } - if err := WriteBlockReceipts(db, block.Hash(), block.NumberU64(), nil); err != nil { + if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), nil); err != nil { return nil, err } - if err := WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { + if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { return nil, err } - if err := WriteHeadBlockHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil { return nil, err } - if err := WriteHeadHeaderHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadHeaderHash(db, block.Hash()); err != nil { return nil, err } config := g.Config if config == nil { config = params.AllEthashProtocolChanges } - return block, WriteChainConfig(db, block.Hash(), config) + return block, rawdb.WriteChainConfig(db, block.Hash(), config) } // MustCommit writes the genesis block and state to db, panicking on error. diff --git a/core/genesis_test.go b/core/genesis_test.go index 177798a5d2..ee32b6705d 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -17,7 +17,6 @@ package core import ( - "github.com/tomochain/tomochain/core/rawdb" "math/big" "reflect" "testing" @@ -25,6 +24,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/params" @@ -155,7 +155,7 @@ func TestSetupGenesis(t *testing.T) { t.Errorf("%s: returned hash %s, want %s", test.name, hash.Hex(), test.wantHash.Hex()) } else if err == nil { // Check database content. - stored := GetBlock(db, test.wantHash, 0) + stored := rawdb.GetBlock(db, test.wantHash, 0) if stored.Hash() != test.wantHash { t.Errorf("%s: block in DB has hash %s, want %s", test.name, stored.Hash(), test.wantHash) } diff --git a/core/headerchain.go b/core/headerchain.go index 8365f2127d..feed409ca1 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -26,9 +26,11 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -66,9 +68,9 @@ type HeaderChain struct { } // NewHeaderChain creates a new HeaderChain structure. -// getValidator should return the parent's validator -// procInterrupt points to the parent's interrupt semaphore -// wg points to the parent's shutdown wait group +// getValidator should return the parent's validator +// procInterrupt points to the parent's interrupt semaphore +// wg points to the parent's shutdown wait group func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, procInterrupt func() bool) (*HeaderChain, error) { headerCache, _ := lru.New(headerCacheLimit) tdCache, _ := lru.New(tdCacheLimit) @@ -97,7 +99,7 @@ func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine c } hc.currentHeader.Store(hc.genesisHeader) - if head := GetHeadBlockHash(chainDb); head != (common.Hash{}) { + if head := rawdb.GetHeadBlockHash(chainDb); head != (common.Hash{}) { if chead := hc.GetHeaderByHash(head); chead != nil { hc.currentHeader.Store(chead) } @@ -113,8 +115,8 @@ func (hc *HeaderChain) GetBlockNumber(hash common.Hash) uint64 { if cached, ok := hc.numberCache.Get(hash); ok { return cached.(uint64) } - number := GetBlockNumber(hc.chainDb, hash) - if number != missingNumber { + number := rawdb.GetBlockNumber(hc.chainDb, hash) + if number != rawdb.MissingNumber { hc.numberCache.Add(hash, number) } return number @@ -147,7 +149,7 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er if err := hc.WriteTd(hash, number, externTd); err != nil { log.Crit("Failed to write header total difficulty", "err", err) } - if err := WriteHeader(hc.chainDb, header); err != nil { + if err := rawdb.WriteHeader(hc.chainDb, header); err != nil { log.Crit("Failed to write header content", "err", err) } // If the total difficulty is higher than our known, add it to the canonical chain @@ -156,11 +158,11 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er if externTd.Cmp(localTd) > 0 || (externTd.Cmp(localTd) == 0 && mrand.Float64() < 0.5) { // Delete any canonical number assignments above the new head for i := number + 1; ; i++ { - hash := GetCanonicalHash(hc.chainDb, i) + hash := rawdb.GetCanonicalHash(hc.chainDb, i) if hash == (common.Hash{}) { break } - DeleteCanonicalHash(hc.chainDb, i) + rawdb.DeleteCanonicalHash(hc.chainDb, i) } // Overwrite any stale canonical number assignments var ( @@ -168,18 +170,18 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er headNumber = header.Number.Uint64() - 1 headHeader = hc.GetHeader(headHash, headNumber) ) - for GetCanonicalHash(hc.chainDb, headNumber) != headHash { - WriteCanonicalHash(hc.chainDb, headHash, headNumber) + for rawdb.GetCanonicalHash(hc.chainDb, headNumber) != headHash { + rawdb.WriteCanonicalHash(hc.chainDb, headHash, headNumber) headHash = headHeader.ParentHash headNumber = headHeader.Number.Uint64() - 1 headHeader = hc.GetHeader(headHash, headNumber) } // Extend the canonical chain with the new header - if err := WriteCanonicalHash(hc.chainDb, hash, number); err != nil { + if err := rawdb.WriteCanonicalHash(hc.chainDb, hash, number); err != nil { log.Crit("Failed to insert header number", "err", err) } - if err := WriteHeadHeaderHash(hc.chainDb, hash); err != nil { + if err := rawdb.WriteHeadHeaderHash(hc.chainDb, hash); err != nil { log.Crit("Failed to insert head header hash", "err", err) } hc.currentHeaderHash = hash @@ -316,7 +318,7 @@ func (hc *HeaderChain) GetTd(hash common.Hash, number uint64) *big.Int { if cached, ok := hc.tdCache.Get(hash); ok { return cached.(*big.Int) } - td := GetTd(hc.chainDb, hash, number) + td := rawdb.GetTd(hc.chainDb, hash, number) if td == nil { return nil } @@ -334,7 +336,7 @@ func (hc *HeaderChain) GetTdByHash(hash common.Hash) *big.Int { // WriteTd stores a block's total difficulty into the database, also caching it // along the way. func (hc *HeaderChain) WriteTd(hash common.Hash, number uint64, td *big.Int) error { - if err := WriteTd(hc.chainDb, hash, number, td); err != nil { + if err := rawdb.WriteTd(hc.chainDb, hash, number, td); err != nil { return err } hc.tdCache.Add(hash, new(big.Int).Set(td)) @@ -348,7 +350,7 @@ func (hc *HeaderChain) GetHeader(hash common.Hash, number uint64) *types.Header if header, ok := hc.headerCache.Get(hash); ok { return header.(*types.Header) } - header := GetHeader(hc.chainDb, hash, number) + header := rawdb.GetHeader(hc.chainDb, hash, number) if header == nil { return nil } @@ -368,14 +370,14 @@ func (hc *HeaderChain) HasHeader(hash common.Hash, number uint64) bool { if hc.numberCache.Contains(hash) || hc.headerCache.Contains(hash) { return true } - ok, _ := hc.chainDb.Has(headerKey(hash, number)) + ok, _ := hc.chainDb.Has(rawdb.HeaderKey(number, hash)) return ok } // GetHeaderByNumber retrieves a block header from the database by number, // caching it (associated with its hash) if found. func (hc *HeaderChain) GetHeaderByNumber(number uint64) *types.Header { - hash := GetCanonicalHash(hc.chainDb, number) + hash := rawdb.GetCanonicalHash(hc.chainDb, number) if hash == (common.Hash{}) { return nil } @@ -390,7 +392,7 @@ func (hc *HeaderChain) CurrentHeader() *types.Header { // SetCurrentHeader sets the current head header of the canonical chain. func (hc *HeaderChain) SetCurrentHeader(head *types.Header) { - if err := WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil { + if err := rawdb.WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil { log.Crit("Failed to insert head header hash", "err", err) } hc.currentHeader.Store(head) @@ -416,13 +418,13 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) { if delFn != nil { delFn(hash, num) } - DeleteHeader(hc.chainDb, hash, num) - DeleteTd(hc.chainDb, hash, num) + rawdb.DeleteHeader(hc.chainDb, hash, num) + rawdb.DeleteTd(hc.chainDb, hash, num) hc.currentHeader.Store(hc.GetHeader(hdr.ParentHash, hdr.Number.Uint64()-1)) } // Roll back the canonical chain numbering for i := height; i > head; i-- { - DeleteCanonicalHash(hc.chainDb, i) + rawdb.DeleteCanonicalHash(hc.chainDb, i) } // Clear out any stale content from the caches hc.headerCache.Purge() @@ -434,7 +436,7 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) { } hc.currentHeaderHash = hc.CurrentHeader().Hash() - if err := WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil { + if err := rawdb.WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil { log.Crit("Failed to reset head header hash", "err", err) } } diff --git a/core/database_util.go b/core/rawdb/accessors_chain.go similarity index 51% rename from core/database_util.go rename to core/rawdb/accessors_chain.go index a5ab18687d..40a5b1d3ee 100644 --- a/core/database_util.go +++ b/core/rawdb/accessors_chain.go @@ -14,22 +14,17 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package core +package rawdb import ( "bytes" "encoding/binary" - "encoding/json" - "errors" - "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" - "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -44,46 +39,6 @@ type DatabaseDeleter interface { Delete(key []byte) error } -var ( - headHeaderKey = []byte("LastHeader") - headBlockKey = []byte("LastBlock") - headFastKey = []byte("LastFast") - trieSyncKey = []byte("TrieSync") - - // Data item prefixes (use single byte to avoid mixing data types, avoid `i`). - headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header - tdSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + tdSuffix -> td - numSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + numSuffix -> hash - blockHashPrefix = []byte("H") // blockHashPrefix + hash -> num (uint64 big endian) - bodyPrefix = []byte("b") // bodyPrefix + num (uint64 big endian) + hash -> block body - blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts - lookupPrefix = []byte("l") // lookupPrefix + hash -> transaction/receipt lookup metadata - bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits - - preimagePrefix = "secure-key-" // preimagePrefix + hash -> preimage - configPrefix = []byte("ethereum-config-") // config prefix for the db - - // Chain index prefixes (use `i` + single byte to avoid mixing data types). - BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress - - // used by old db, now only used for conversion - oldReceiptsPrefix = []byte("receipts-") - oldTxMetaSuffix = []byte{0x01} - - ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error - - preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil) - preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil) -) - -// TxLookupEntry is a positional metadata to help looking up the data content of -// a transaction or receipt given only its hash. -type TxLookupEntry struct { - BlockHash common.Hash - BlockIndex uint64 - Index uint64 -} - // encodeBlockNumber encodes a block number as big endian uint64 func encodeBlockNumber(number uint64) []byte { enc := make([]byte, 8) @@ -93,23 +48,23 @@ func encodeBlockNumber(number uint64) []byte { // GetCanonicalHash retrieves a hash assigned to a canonical block number. func GetCanonicalHash(db DatabaseReader, number uint64) common.Hash { - data, _ := db.Get(append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...)) + data, _ := db.Get(headerHashKey(number)) if len(data) == 0 { return common.Hash{} } return common.BytesToHash(data) } -// missingNumber is returned by GetBlockNumber if no header with the +// MissingNumber is returned by GetBlockNumber if no header with the // given block hash has been stored in the database -const missingNumber = uint64(0xffffffffffffffff) +const MissingNumber = uint64(0xffffffffffffffff) // GetBlockNumber returns the block number assigned to a block hash // if the corresponding header is present in the database func GetBlockNumber(db DatabaseReader, hash common.Hash) uint64 { - data, _ := db.Get(append(blockHashPrefix, hash.Bytes()...)) + data, _ := db.Get(headerNumberKey(hash)) if len(data) != 8 { - return missingNumber + return MissingNumber } return binary.BigEndian.Uint64(data) } @@ -149,7 +104,7 @@ func GetHeadFastBlockHash(db DatabaseReader) common.Hash { } // GetTrieSyncProgress retrieves the number of tries nodes fast synced to allow -// reportinc correct numbers across restarts. +// reporting correct numbers across restarts. func GetTrieSyncProgress(db DatabaseReader) uint64 { data, _ := db.Get(trieSyncKey) if len(data) == 0 { @@ -161,7 +116,7 @@ func GetTrieSyncProgress(db DatabaseReader) uint64 { // GetHeaderRLP retrieves a block header in its raw RLP database encoding, or nil // if the header's not found. func GetHeaderRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { - data, _ := db.Get(headerKey(hash, number)) + data, _ := db.Get(HeaderKey(number, hash)) return data } @@ -182,19 +137,11 @@ func GetHeader(db DatabaseReader, hash common.Hash, number uint64) *types.Header // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. func GetBodyRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { - data, _ := db.Get(blockBodyKey(hash, number)) + data, _ := db.Get(BlockBodyKey(number, hash)) return data } -func headerKey(hash common.Hash, number uint64) []byte { - return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) -} - -func blockBodyKey(hash common.Hash, number uint64) []byte { - return append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) -} - -// GetBody retrieves the block body (transactons, uncles) corresponding to the +// GetBody retrieves the block body (transactions, uncles) corresponding to the // hash, nil if none found. func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body { data := GetBodyRLP(db, hash, number) @@ -212,7 +159,7 @@ func GetBody(db DatabaseReader, hash common.Hash, number uint64) *types.Body { // GetTd retrieves a block's total difficulty corresponding to the hash, nil if // none found. func GetTd(db DatabaseReader, hash common.Hash, number uint64) *big.Int { - data, _ := db.Get(append(append(append(headerPrefix, encodeBlockNumber(number)...), hash[:]...), tdSuffix...)) + data, _ := db.Get(headerTDKey(number, hash)) if len(data) == 0 { return nil } @@ -244,14 +191,25 @@ func GetBlock(db DatabaseReader, hash common.Hash, number uint64) *types.Block { return types.NewBlockWithHeader(header).WithBody(body.Transactions, body.Uncles) } -// GetBlockReceipts retrieves the receipts generated by the transactions included -// in a block given by its hash. -func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts { +// ReadReceiptsRLP retrieves all the transaction receipts belonging to a block in RLP encoding. +func ReadReceiptsRLP(db DatabaseReader, hash common.Hash, number uint64) rlp.RawValue { data, _ := db.Get(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash[:]...)) if len(data) == 0 { return nil } - storageReceipts := []*types.ReceiptForStorage{} + return data +} + +// ReadRawReceipts retrieves all the transaction receipts belonging to a block. +// The receipt metadata fields are not guaranteed to be populated, so they +// should not be used. Use ReadReceipts instead if the metadata is needed. +func ReadRawReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts { + // Retrieve the flattened receipt slice + data := ReadReceiptsRLP(db, hash, number) + if len(data) == 0 { + return nil + } + var storageReceipts []*types.ReceiptForStorage if err := rlp.DecodeBytes(data, &storageReceipts); err != nil { log.Error("Invalid receipt array RLP", "hash", hash, "err", err) return nil @@ -263,100 +221,30 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types. return receipts } -// GetTxLookupEntry retrieves the positional metadata associated with a transaction -// hash to allow retrieving the transaction or receipt by hash. -func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) { - // Load the positional metadata from disk and bail if it fails - data, _ := db.Get(append(lookupPrefix, hash.Bytes()...)) - if len(data) == 0 { - return common.Hash{}, 0, 0 - } - // Parse and return the contents of the lookup entry - var entry TxLookupEntry - if err := rlp.DecodeBytes(data, &entry); err != nil { - log.Error("Invalid lookup entry RLP", "hash", hash, "err", err) - return common.Hash{}, 0, 0 - } - return entry.BlockHash, entry.BlockIndex, entry.Index -} - -// GetTransaction retrieves a specific transaction from the database, along with -// its added positional metadata. -func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, common.Hash, uint64, uint64) { - // Retrieve the lookup metadata and resolve the transaction from the body - blockHash, blockNumber, txIndex := GetTxLookupEntry(db, hash) - - if blockHash != (common.Hash{}) { - body := GetBody(db, blockHash, blockNumber) - if body == nil || len(body.Transactions) <= int(txIndex) { - log.Error("Transaction referenced missing", "number", blockNumber, "hash", blockHash, "index", txIndex) - return nil, common.Hash{}, 0, 0 - } - return body.Transactions[txIndex], blockHash, blockNumber, txIndex - } - // Old transaction representation, load the transaction and it's metadata separately - data, _ := db.Get(hash.Bytes()) - if len(data) == 0 { - return nil, common.Hash{}, 0, 0 - } - var tx types.Transaction - if err := rlp.DecodeBytes(data, &tx); err != nil { - return nil, common.Hash{}, 0, 0 - } - // Retrieve the blockchain positional metadata - data, _ = db.Get(append(hash.Bytes(), oldTxMetaSuffix...)) - if len(data) == 0 { - return nil, common.Hash{}, 0, 0 - } - var entry TxLookupEntry - if err := rlp.DecodeBytes(data, &entry); err != nil { - return nil, common.Hash{}, 0, 0 - } - return &tx, entry.BlockHash, entry.BlockIndex, entry.Index -} - -// GetReceipt retrieves a specific transaction receipt from the database, along with -// its added positional metadata. -func GetReceipt(db DatabaseReader, hash common.Hash) (*types.Receipt, common.Hash, uint64, uint64) { - // Retrieve the lookup metadata and resolve the receipt from the receipts - blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash) - - if blockHash != (common.Hash{}) { - receipts := GetBlockReceipts(db, blockHash, blockNumber) - if len(receipts) <= int(receiptIndex) { - log.Error("Receipt refereced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex) - return nil, common.Hash{}, 0, 0 - } - return receipts[receiptIndex], blockHash, blockNumber, receiptIndex +// GetBlockReceipts retrieves the receipts generated by the transactions included +// in a block given by its hash. +func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64, config *params.ChainConfig) types.Receipts { + // We're deriving many fields from the block body, retrieve beside the receipt + receipts := ReadRawReceipts(db, hash, number) + if receipts == nil { + return nil } - // Old receipt representation, load the receipt and set an unknown metadata - data, _ := db.Get(append(oldReceiptsPrefix, hash[:]...)) - if len(data) == 0 { - return nil, common.Hash{}, 0, 0 + body := GetBody(db, hash, number) + if body == nil { + log.Error("Missing body but have receipt", "hash", hash, "number", number) + return nil } - var receipt types.ReceiptForStorage - err := rlp.DecodeBytes(data, &receipt) - if err != nil { - log.Error("Invalid receipt RLP", "hash", hash, "err", err) + if err := receipts.DeriveFields(config, hash, number, body.Transactions); err != nil { + log.Error("Failed to derive block receipts fields", "hash", hash, "number", number, "err", err) + return nil } - return (*types.Receipt)(&receipt), common.Hash{}, 0, 0 -} - -// GetBloomBits retrieves the compressed bloom bit vector belonging to the given -// section and bit index from the. -func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { - key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...) - binary.BigEndian.PutUint16(key[1:], uint16(bit)) - binary.BigEndian.PutUint64(key[3:], section) - - return db.Get(key) + return receipts } // WriteCanonicalHash stores the canonical hash for the given block number. func WriteCanonicalHash(db ethdb.KeyValueWriter, hash common.Hash, number uint64) error { - key := append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...) - if err := db.Put(key, hash.Bytes()); err != nil { + if err := db.Put(headerHashKey(number), hash.Bytes()); err != nil { log.Crit("Failed to store number to hash mapping", "err", err) } return nil @@ -401,15 +289,13 @@ func WriteHeader(db ethdb.KeyValueWriter, header *types.Header) error { if err != nil { return err } - hash := header.Hash().Bytes() + hash := header.Hash() num := header.Number.Uint64() encNum := encodeBlockNumber(num) - key := append(blockHashPrefix, hash...) - if err := db.Put(key, encNum); err != nil { + if err := db.Put(headerNumberKey(hash), encNum); err != nil { log.Crit("Failed to store hash to number mapping", "err", err) } - key = append(append(headerPrefix, encNum...), hash...) - if err := db.Put(key, data); err != nil { + if err := db.Put(headerKey(num, hash), data); err != nil { log.Crit("Failed to store header", "err", err) } return nil @@ -426,8 +312,7 @@ func WriteBody(db ethdb.KeyValueWriter, hash common.Hash, number uint64, body *t // WriteBodyRLP writes a serialized body of a block into the database. func WriteBodyRLP(db ethdb.KeyValueWriter, hash common.Hash, number uint64, rlp rlp.RawValue) error { - key := append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) - if err := db.Put(key, rlp); err != nil { + if err := db.Put(BlockBodyKey(number, hash), rlp); err != nil { log.Crit("Failed to store block body", "err", err) } return nil @@ -439,8 +324,7 @@ func WriteTd(db ethdb.KeyValueWriter, hash common.Hash, number uint64, td *big.I if err != nil { return err } - key := append(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...), tdSuffix...) - if err := db.Put(key, data); err != nil { + if err := db.Put(headerTDKey(number, hash), data); err != nil { log.Crit("Failed to store block total difficulty", "err", err) } return nil @@ -473,66 +357,31 @@ func WriteBlockReceipts(db ethdb.KeyValueWriter, hash common.Hash, number uint64 return err } // Store the flattened receipt slice - key := append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...) - if err := db.Put(key, bytes); err != nil { + if err := db.Put(blockReceiptsKey(number, hash), bytes); err != nil { log.Crit("Failed to store block receipts", "err", err) } return nil } -// WriteTxLookupEntries stores a positional metadata for every transaction from -// a block, enabling hash based transaction and receipt lookups. -func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error { - // Iterate over each transaction and encode its metadata - for i, tx := range block.Transactions() { - entry := TxLookupEntry{ - BlockHash: block.Hash(), - BlockIndex: block.NumberU64(), - Index: uint64(i), - } - data, err := rlp.EncodeToBytes(entry) - if err != nil { - return err - } - if err := db.Put(append(lookupPrefix, tx.Hash().Bytes()...), data); err != nil { - return err - } - } - return nil -} - -// WriteBloomBits writes the compressed bloom bits vector belonging to the given -// section and bit index. -func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) { - key := append(append(bloomBitsPrefix, make([]byte, 10)...), head.Bytes()...) - - binary.BigEndian.PutUint16(key[1:], uint16(bit)) - binary.BigEndian.PutUint64(key[3:], section) - - if err := db.Put(key, bits); err != nil { - log.Crit("Failed to store bloom bits", "err", err) - } -} - // DeleteCanonicalHash removes the number to hash canonical mapping. func DeleteCanonicalHash(db DatabaseDeleter, number uint64) { - db.Delete(append(append(headerPrefix, encodeBlockNumber(number)...), numSuffix...)) + db.Delete(headerHashKey(number)) } // DeleteHeader removes all block header data associated with a hash. func DeleteHeader(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(blockHashPrefix, hash.Bytes()...)) - db.Delete(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...)) + db.Delete(headerNumberKey(hash)) + db.Delete(headerKey(number, hash)) } // DeleteBody removes all block body data associated with a hash. func DeleteBody(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(append(bodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...)) + db.Delete(BlockBodyKey(number, hash)) } // DeleteTd removes all block total difficulty data associated with a hash. func DeleteTd(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...), tdSuffix...)) + db.Delete(headerTDKey(number, hash)) } // DeleteBlock removes all block data associated with a hash. @@ -545,84 +394,7 @@ func DeleteBlock(db DatabaseDeleter, hash common.Hash, number uint64) { // DeleteBlockReceipts removes all receipt data associated with a block hash. func DeleteBlockReceipts(db DatabaseDeleter, hash common.Hash, number uint64) { - db.Delete(append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...)) -} - -// DeleteTxLookupEntry removes all transaction data associated with a hash. -func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) { - db.Delete(append(lookupPrefix, hash.Bytes()...)) -} - -// PreimageTable returns a Database instance with the key prefix for preimage entries. -func PreimageTable(db ethdb.Database) ethdb.Database { - return rawdb.NewTable(db, preimagePrefix) -} - -// WritePreimages writes the provided set of preimages to the database. `number` is the -// current block number, and is used for debug messages only. -func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash][]byte) error { - table := PreimageTable(db) - batch := table.NewBatch() - hitCount := 0 - for hash, preimage := range preimages { - if _, err := table.Get(hash.Bytes()); err != nil { - batch.Put(hash.Bytes(), preimage) - hitCount++ - } - } - preimageCounter.Inc(int64(len(preimages))) - preimageHitCounter.Inc(int64(hitCount)) - if hitCount > 0 { - if err := batch.Write(); err != nil { - return fmt.Errorf("preimage write fail for block %d: %v", number, err) - } - } - return nil -} - -// GetBlockChainVersion reads the version number from db. -func GetBlockChainVersion(db DatabaseReader) int { - var vsn uint - enc, _ := db.Get([]byte("BlockchainVersion")) - rlp.DecodeBytes(enc, &vsn) - return int(vsn) -} - -// WriteBlockChainVersion writes vsn as the version number to db. -func WriteBlockChainVersion(db ethdb.KeyValueWriter, vsn int) { - enc, _ := rlp.EncodeToBytes(uint(vsn)) - db.Put([]byte("BlockchainVersion"), enc) -} - -// WriteChainConfig writes the chain config settings to the database. -func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.ChainConfig) error { - // short circuit and ignore if nil config. GetChainConfig - // will return a default. - if cfg == nil { - return nil - } - - jsonChainConfig, err := json.Marshal(cfg) - if err != nil { - return err - } - - return db.Put(append(configPrefix, hash[:]...), jsonChainConfig) -} - -// GetChainConfig will fetch the network settings based on the given hash. -func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) { - jsonChainConfig, _ := db.Get(append(configPrefix, hash[:]...)) - if len(jsonChainConfig) == 0 { - return nil, ErrChainConfigNotFound - } - - var config params.ChainConfig - if err := json.Unmarshal(jsonChainConfig, &config); err != nil { - return nil, err - } - - return &config, nil + db.Delete(blockReceiptsKey(number, hash)) } // FindCommonAncestor returns the last common ancestor of two block headers diff --git a/core/database_util_test.go b/core/rawdb/accessors_chain_test.go similarity index 83% rename from core/database_util_test.go rename to core/rawdb/accessors_chain_test.go index f28ca160a5..c85bae2358 100644 --- a/core/database_util_test.go +++ b/core/rawdb/accessors_chain_test.go @@ -14,23 +14,27 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package core +package rawdb import ( "bytes" - "github.com/tomochain/tomochain/core/rawdb" + "encoding/hex" + "fmt" "math/big" "testing" + "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto/sha3" + "github.com/tomochain/tomochain/internal/blocktest" "github.com/tomochain/tomochain/rlp" ) // Tests block header storage and retrieval operations. func TestHeaderStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test header to move around the database and make sure it's really new header := &types.Header{Number: big.NewInt(42), Extra: []byte("test header")} @@ -65,7 +69,7 @@ func TestHeaderStorage(t *testing.T) { // Tests block body storage and retrieval operations. func TestBodyStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test body to move around the database and make sure it's really new body := &types.Body{Uncles: []*types.Header{{Extra: []byte("test header")}}} @@ -83,7 +87,7 @@ func TestBodyStorage(t *testing.T) { } if entry := GetBody(db, hash, 0); entry == nil { t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(types.Transactions(body.Transactions)) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { + } else if types.DeriveSha(types.Transactions(entry.Transactions), blocktest.NewHasher()) != types.DeriveSha(types.Transactions(body.Transactions), blocktest.NewHasher()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(body.Uncles) { t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, body) } if entry := GetBodyRLP(db, hash, 0); entry == nil { @@ -105,7 +109,7 @@ func TestBodyStorage(t *testing.T) { // Tests block storage and retrieval operations. func TestBlockStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test block to move around the database and make sure it's really new block := types.NewBlockWithHeader(&types.Header{ @@ -139,7 +143,7 @@ func TestBlockStorage(t *testing.T) { } if entry := GetBody(db, block.Hash(), block.NumberU64()); entry == nil { t.Fatalf("Stored body not found") - } else if types.DeriveSha(types.Transactions(entry.Transactions)) != types.DeriveSha(block.Transactions()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { + } else if types.DeriveSha(types.Transactions(entry.Transactions), blocktest.NewHasher()) != types.DeriveSha(block.Transactions(), blocktest.NewHasher()) || types.CalcUncleHash(entry.Uncles) != types.CalcUncleHash(block.Uncles()) { t.Fatalf("Retrieved body mismatch: have %v, want %v", entry, block.Body()) } // Delete the block and verify the execution @@ -157,7 +161,7 @@ func TestBlockStorage(t *testing.T) { // Tests that partial block contents don't get reassembled into full blocks. func TestPartialBlockStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() block := types.NewBlockWithHeader(&types.Header{ Extra: []byte("test block"), UncleHash: types.EmptyUncleHash, @@ -198,7 +202,7 @@ func TestPartialBlockStorage(t *testing.T) { // Tests block total difficulty storage and retrieval operations. func TestTdStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test TD to move around the database and make sure it's really new hash, td := common.Hash{}, big.NewInt(314) @@ -223,7 +227,7 @@ func TestTdStorage(t *testing.T) { // Tests that canonical numbers can be mapped to hashes and retrieved. func TestCanonicalMappingStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() // Create a test canonical number and assinged hash to move around hash, number := common.Hash{0: 0xff}, uint64(314) @@ -248,7 +252,7 @@ func TestCanonicalMappingStorage(t *testing.T) { // Tests that head headers and head blocks can be assigned, individually. func TestHeadStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() blockHead := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block header")}) blockFull := types.NewBlockWithHeader(&types.Header{Extra: []byte("test block full")}) @@ -288,14 +292,14 @@ func TestHeadStorage(t *testing.T) { // Tests that positional lookup metadata can be stored and retrieved. func TestLookupStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() tx1 := types.NewTransaction(1, common.BytesToAddress([]byte{0x11}), big.NewInt(111), 1111, big.NewInt(11111), []byte{0x11, 0x11, 0x11}) tx2 := types.NewTransaction(2, common.BytesToAddress([]byte{0x22}), big.NewInt(222), 2222, big.NewInt(22222), []byte{0x22, 0x22, 0x22}) tx3 := types.NewTransaction(3, common.BytesToAddress([]byte{0x33}), big.NewInt(333), 3333, big.NewInt(33333), []byte{0x33, 0x33, 0x33}) txs := []*types.Transaction{tx1, tx2, tx3} - block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil) + block := types.NewBlock(&types.Header{Number: big.NewInt(314)}, txs, nil, nil, blocktest.NewHasher()) // Check that no transactions entries are in a pristine database for i, tx := range txs { @@ -333,8 +337,15 @@ func TestLookupStorage(t *testing.T) { // Tests that receipts associated with a single block can be stored and retrieved. func TestBlockReceiptStorage(t *testing.T) { - db := rawdb.NewMemoryDatabase() + db := NewMemoryDatabase() + + // Create a live block since we need metadata to reconstruct the receipt + tx1 := types.NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) + tx2 := types.NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil) + + body := &types.Body{Transactions: types.Transactions{tx1, tx2}} + // Create the two receipts to manage afterwards receipt1 := &types.Receipt{ Status: types.ReceiptStatusFailed, CumulativeGasUsed: 1, @@ -342,10 +353,12 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x11})}, {Address: common.BytesToAddress([]byte{0x01, 0x11})}, }, - TxHash: common.BytesToHash([]byte{0x11, 0x11}), + TxHash: tx1.Hash(), ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), GasUsed: 111111, } + receipt1.Bloom = types.CreateBloom(types.Receipts{receipt1}) + receipt2 := &types.Receipt{ PostState: common.Hash{2}.Bytes(), CumulativeGasUsed: 2, @@ -353,36 +366,64 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x22})}, {Address: common.BytesToAddress([]byte{0x02, 0x22})}, }, - TxHash: common.BytesToHash([]byte{0x22, 0x22}), + TxHash: tx2.Hash(), ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}), GasUsed: 222222, } + receipt2.Bloom = types.CreateBloom(types.Receipts{receipt2}) receipts := []*types.Receipt{receipt1, receipt2} // Check that no receipt entries are in a pristine database hash := common.BytesToHash([]byte{0x03, 0x14}) - if rs := GetBlockReceipts(db, hash, 0); len(rs) != 0 { + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { t.Fatalf("non existent receipts returned: %v", rs) } + // Insert the body that corresponds to the receipts + WriteBody(db, hash, 0, body) + // Insert the receipt slice into the database and check presence - if err := WriteBlockReceipts(db, hash, 0, receipts); err != nil { - t.Fatalf("failed to write block receipts: %v", err) - } - if rs := GetBlockReceipts(db, hash, 0); len(rs) == 0 { + WriteBlockReceipts(db, hash, 0, receipts) + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) == 0 { t.Fatalf("no receipts returned") } else { - for i := 0; i < len(receipts); i++ { - rlpHave, _ := rlp.EncodeToBytes(rs[i]) - rlpWant, _ := rlp.EncodeToBytes(receipts[i]) - - if !bytes.Equal(rlpHave, rlpWant) { - t.Fatalf("receipt #%d: receipt mismatch: have %v, want %v", i, rs[i], receipts[i]) - } + if err := checkReceiptsRLP(rs, receipts); err != nil { + t.Fatalf(err.Error()) } } - // Delete the receipt slice and check purge + // Delete the body and ensure that the receipts are no longer returned (metadata can't be recomputed) + DeleteBody(db, hash, 0) + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); rs != nil { + t.Fatalf("receipts returned when body was deleted: %v", rs) + } + // Ensure that receipts without metadata can be returned without the block body too + if err := checkReceiptsRLP(ReadRawReceipts(db, hash, 0), receipts); err != nil { + t.Fatalf(err.Error()) + } + // Sanity check that body alone without the receipt is a full purge + WriteBody(db, hash, 0, body) + DeleteBlockReceipts(db, hash, 0) - if rs := GetBlockReceipts(db, hash, 0); len(rs) != 0 { + if rs := GetBlockReceipts(db, hash, 0, params.TestChainConfig); len(rs) != 0 { t.Fatalf("deleted receipts returned: %v", rs) } } + +func checkReceiptsRLP(have, want types.Receipts) error { + if len(have) != len(want) { + return fmt.Errorf("receipts sizes mismatch: have %d, want %d", len(have), len(want)) + } + for i := 0; i < len(want); i++ { + rlpHave, err := rlp.EncodeToBytes(have[i]) + if err != nil { + return err + } + rlpWant, err := rlp.EncodeToBytes(want[i]) + if err != nil { + return err + } + if !bytes.Equal(rlpHave, rlpWant) { + return fmt.Errorf("receipt #%d: receipt mismatch: have %s, want %s", i, hex.EncodeToString(rlpHave), hex.EncodeToString(rlpWant)) + } + } + return nil +} diff --git a/core/rawdb/accessors_indexes.go b/core/rawdb/accessors_indexes.go new file mode 100644 index 0000000000..0bb54d65b6 --- /dev/null +++ b/core/rawdb/accessors_indexes.go @@ -0,0 +1,145 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/rlp" +) + +// GetTxLookupEntry retrieves the positional metadata associated with a transaction +// hash to allow retrieving the transaction or receipt by hash. +func GetTxLookupEntry(db DatabaseReader, hash common.Hash) (common.Hash, uint64, uint64) { + // Load the positional metadata from disk and bail if it fails + data, _ := db.Get(txLookupKey(hash)) + if len(data) == 0 { + return common.Hash{}, 0, 0 + } + // Parse and return the contents of the lookup entry + var entry TxLookupEntry + if err := rlp.DecodeBytes(data, &entry); err != nil { + log.Error("Invalid lookup entry RLP", "hash", hash, "err", err) + return common.Hash{}, 0, 0 + } + return entry.BlockHash, entry.BlockIndex, entry.Index +} + +// WriteTxLookupEntries stores a positional metadata for every transaction from +// a block, enabling hash based transaction and receipt lookups. +func WriteTxLookupEntries(db ethdb.KeyValueWriter, block *types.Block) error { + // Iterate over each transaction and encode its metadata + for i, tx := range block.Transactions() { + entry := TxLookupEntry{ + BlockHash: block.Hash(), + BlockIndex: block.NumberU64(), + Index: uint64(i), + } + data, err := rlp.EncodeToBytes(entry) + if err != nil { + return err + } + if err := db.Put(txLookupKey(tx.Hash()), data); err != nil { + return err + } + } + return nil +} + +// DeleteTxLookupEntry removes all transaction data associated with a hash. +func DeleteTxLookupEntry(db DatabaseDeleter, hash common.Hash) { + db.Delete(txLookupKey(hash)) +} + +// GetTransaction retrieves a specific transaction from the database, along with +// its added positional metadata. +func GetTransaction(db DatabaseReader, hash common.Hash) (*types.Transaction, common.Hash, uint64, uint64) { + // Retrieve the lookup metadata and resolve the transaction from the body + blockHash, blockNumber, txIndex := GetTxLookupEntry(db, hash) + + if blockHash != (common.Hash{}) { + body := GetBody(db, blockHash, blockNumber) + if body == nil || len(body.Transactions) <= int(txIndex) { + log.Error("Transaction referenced missing", "number", blockNumber, "hash", blockHash, "index", txIndex) + return nil, common.Hash{}, 0, 0 + } + return body.Transactions[txIndex], blockHash, blockNumber, txIndex + } + // Old transaction representation, load the transaction and its metadata separately + data, _ := db.Get(hash.Bytes()) + if len(data) == 0 { + return nil, common.Hash{}, 0, 0 + } + var tx types.Transaction + if err := rlp.DecodeBytes(data, &tx); err != nil { + return nil, common.Hash{}, 0, 0 + } + // Retrieve the blockchain positional metadata + data, _ = db.Get(oldTxMetaKey(hash)) + if len(data) == 0 { + return nil, common.Hash{}, 0, 0 + } + var entry TxLookupEntry + if err := rlp.DecodeBytes(data, &entry); err != nil { + return nil, common.Hash{}, 0, 0 + } + return &tx, entry.BlockHash, entry.BlockIndex, entry.Index +} + +// GetReceipt retrieves a specific transaction receipt from the database, along with +// its added positional metadata. +func GetReceipt(db DatabaseReader, hash common.Hash, config *params.ChainConfig) (*types.Receipt, common.Hash, uint64, uint64) { + // Retrieve the lookup metadata and resolve the receipt from the receipts + blockHash, blockNumber, receiptIndex := GetTxLookupEntry(db, hash) + + if blockHash != (common.Hash{}) { + receipts := GetBlockReceipts(db, blockHash, blockNumber, config) + if len(receipts) <= int(receiptIndex) { + log.Error("Receipt refereced missing", "number", blockNumber, "hash", blockHash, "index", receiptIndex) + return nil, common.Hash{}, 0, 0 + } + return receipts[receiptIndex], blockHash, blockNumber, receiptIndex + } + // Old receipt representation, load the receipt and set an unknown metadata + data, _ := db.Get(append(oldReceiptsPrefix, hash[:]...)) + if len(data) == 0 { + return nil, common.Hash{}, 0, 0 + } + var receipt types.ReceiptForStorage + err := rlp.DecodeBytes(data, &receipt) + if err != nil { + log.Error("Invalid receipt RLP", "hash", hash, "err", err) + } + return (*types.Receipt)(&receipt), common.Hash{}, 0, 0 +} + +// GetBloomBits retrieves the compressed bloom bit vector belonging to the given +// bit index and section indexes. +func GetBloomBits(db DatabaseReader, bit uint, section uint64, head common.Hash) ([]byte, error) { + return db.Get(bloomBitsKey(bit, section, head)) +} + +// WriteBloomBits writes the compressed bloom bits vector belonging to the given +// section and bit index. +func WriteBloomBits(db ethdb.KeyValueWriter, bit uint, section uint64, head common.Hash, bits []byte) { + if err := db.Put(bloomBitsKey(bit, section, head), bits); err != nil { + log.Crit("Failed to store bloom bits", "err", err) + } +} diff --git a/core/rawdb/accessors_metadata.go b/core/rawdb/accessors_metadata.go new file mode 100644 index 0000000000..16fbbd77b0 --- /dev/null +++ b/core/rawdb/accessors_metadata.go @@ -0,0 +1,71 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "encoding/json" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/rlp" +) + +// GetBlockChainVersion reads the version number from db. +func GetBlockChainVersion(db DatabaseReader) int { + var vsn uint + enc, _ := db.Get([]byte("BlockchainVersion")) + rlp.DecodeBytes(enc, &vsn) + return int(vsn) +} + +// WriteBlockChainVersion writes vsn as the version number to db. +func WriteBlockChainVersion(db ethdb.KeyValueWriter, vsn int) { + enc, _ := rlp.EncodeToBytes(uint(vsn)) + db.Put([]byte("BlockchainVersion"), enc) +} + +// WriteChainConfig writes the chain config settings to the database. +func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.ChainConfig) error { + // short circuit and ignore if nil config. GetChainConfig + // will return a default. + if cfg == nil { + return nil + } + + jsonChainConfig, err := json.Marshal(cfg) + if err != nil { + return err + } + + return db.Put(configKey(hash), jsonChainConfig) +} + +// GetChainConfig will fetch the network settings based on the given hash. +func GetChainConfig(db DatabaseReader, hash common.Hash) (*params.ChainConfig, error) { + jsonChainConfig, _ := db.Get(configKey(hash)) + if len(jsonChainConfig) == 0 { + return nil, ErrChainConfigNotFound + } + + var config params.ChainConfig + if err := json.Unmarshal(jsonChainConfig, &config); err != nil { + return nil, err + } + + return &config, nil +} diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go new file mode 100644 index 0000000000..6ef285019b --- /dev/null +++ b/core/rawdb/accessors_snapshot.go @@ -0,0 +1,135 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" +) + +// ReadSnapshotRoot retrieves the root of the block whose state is contained in +// the persisted snapshot. +func ReadSnapshotRoot(db ethdb.KeyValueReader) common.Hash { + data, _ := db.Get(snapshotRootKey) + if len(data) != common.HashLength { + return common.Hash{} + } + return common.BytesToHash(data) +} + +// WriteSnapshotRoot stores the root of the block whose state is contained in +// the persisted snapshot. +func WriteSnapshotRoot(db ethdb.KeyValueWriter, root common.Hash) { + if err := db.Put(snapshotRootKey, root[:]); err != nil { + log.Crit("Failed to store snapshot root", "err", err) + } +} + +// DeleteSnapshotRoot deletes the hash of the block whose state is contained in +// the persisted snapshot. Since snapshots are not immutable, this method can +// be used during updates, so a crash or failure will mark the entire snapshot +// invalid. +func DeleteSnapshotRoot(db ethdb.KeyValueWriter) { + if err := db.Delete(snapshotRootKey); err != nil { + log.Crit("Failed to remove snapshot root", "err", err) + } +} + +// ReadAccountSnapshot retrieves the snapshot entry of an account trie leaf. +func ReadAccountSnapshot(db ethdb.KeyValueReader, hash common.Hash) []byte { + data, _ := db.Get(accountSnapshotKey(hash)) + return data +} + +// WriteAccountSnapshot stores the snapshot entry of an account trie leaf. +func WriteAccountSnapshot(db ethdb.KeyValueWriter, hash common.Hash, entry []byte) { + if err := db.Put(accountSnapshotKey(hash), entry); err != nil { + log.Crit("Failed to store account snapshot", "err", err) + } +} + +// DeleteAccountSnapshot removes the snapshot entry of an account trie leaf. +func DeleteAccountSnapshot(db ethdb.KeyValueWriter, hash common.Hash) { + if err := db.Delete(accountSnapshotKey(hash)); err != nil { + log.Crit("Failed to delete account snapshot", "err", err) + } +} + +// ReadStorageSnapshot retrieves the snapshot entry of an storage trie leaf. +func ReadStorageSnapshot(db ethdb.KeyValueReader, accountHash, storageHash common.Hash) []byte { + data, _ := db.Get(storageSnapshotKey(accountHash, storageHash)) + return data +} + +// WriteStorageSnapshot stores the snapshot entry of an storage trie leaf. +func WriteStorageSnapshot(db ethdb.KeyValueWriter, accountHash, storageHash common.Hash, entry []byte) { + if err := db.Put(storageSnapshotKey(accountHash, storageHash), entry); err != nil { + log.Crit("Failed to store storage snapshot", "err", err) + } +} + +// DeleteStorageSnapshot removes the snapshot entry of an storage trie leaf. +func DeleteStorageSnapshot(db ethdb.KeyValueWriter, accountHash, storageHash common.Hash) { + if err := db.Delete(storageSnapshotKey(accountHash, storageHash)); err != nil { + log.Crit("Failed to delete storage snapshot", "err", err) + } +} + +// IterateStorageSnapshots returns an iterator for walking the entire storage +// space of a specific account. +func IterateStorageSnapshots(db ethdb.Iteratee, accountHash common.Hash) ethdb.Iterator { + return NewKeyLengthIterator(db.NewIterator(storageSnapshotsKey(accountHash), nil), len(SnapshotStoragePrefix)+2*common.HashLength) +} + +// ReadSnapshotJournal retrieves the serialized in-memory diff layers saved at +// the last shutdown. The blob is expected to be max a few 10s of megabytes. +func ReadSnapshotJournal(db ethdb.KeyValueReader) []byte { + data, _ := db.Get(snapshotJournalKey) + return data +} + +// WriteSnapshotJournal stores the serialized in-memory diff layers to save at +// shutdown. The blob is expected to be max a few 10s of megabytes. +func WriteSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) { + if err := db.Put(snapshotJournalKey, journal); err != nil { + log.Crit("Failed to store snapshot journal", "err", err) + } +} + +// DeleteSnapshotJournal deletes the serialized in-memory diff layers saved at +// the last shutdown +func DeleteSnapshotJournal(db ethdb.KeyValueWriter) { + if err := db.Delete(snapshotJournalKey); err != nil { + log.Crit("Failed to remove snapshot journal", "err", err) + } +} + +// ReadSnapshotGenerator retrieves the serialized snapshot generator saved at +// the last shutdown. +func ReadSnapshotGenerator(db ethdb.KeyValueReader) []byte { + data, _ := db.Get(snapshotGeneratorKey) + return data +} + +// WriteSnapshotGenerator stores the serialized snapshot generator to save at +// shutdown. +func WriteSnapshotGenerator(db ethdb.KeyValueWriter, generator []byte) { + if err := db.Put(snapshotGeneratorKey, generator); err != nil { + log.Crit("Failed to store snapshot generator", "err", err) + } +} diff --git a/core/rawdb/accessors_state.go b/core/rawdb/accessors_state.go new file mode 100644 index 0000000000..28bba40f3c --- /dev/null +++ b/core/rawdb/accessors_state.go @@ -0,0 +1,58 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import ( + "fmt" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" +) + +// PreimageTable returns a Database instance with the key prefix for preimage entries. +func PreimageTable(db ethdb.Database) ethdb.Database { + return NewTable(db, PreimagePrefix) +} + +// ReadPreimage retrieves a single preimage of the provided hash. +func ReadPreimage(db ethdb.Database, hash common.Hash) []byte { + table := PreimageTable(db) + data, _ := table.Get(hash.Bytes()) + return data +} + +// WritePreimages writes the provided set of preimages to the database. `number` is the +// current block number, and is used for debug messages only. +func WritePreimages(db ethdb.Database, number uint64, preimages map[common.Hash][]byte) error { + table := PreimageTable(db) + batch := table.NewBatch() + hitCount := 0 + for hash, preimage := range preimages { + if _, err := table.Get(hash.Bytes()); err != nil { + batch.Put(hash.Bytes(), preimage) + hitCount++ + } + } + preimageCounter.Inc(int64(len(preimages))) + preimageHitCounter.Inc(int64(hitCount)) + if hitCount > 0 { + if err := batch.Write(); err != nil { + return fmt.Errorf("preimage write fail for block %d: %v", number, err) + } + } + return nil +} diff --git a/core/rawdb/accessors_trie.go b/core/rawdb/accessors_trie.go new file mode 100644 index 0000000000..7e1bbcaa2f --- /dev/null +++ b/core/rawdb/accessors_trie.go @@ -0,0 +1,64 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see + +package rawdb + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" +) + +// HashScheme is the legacy hash-based state scheme with which trie nodes are +// stored in the disk with node hash as the database key. The advantage of this +// scheme is that different versions of trie nodes can be stored in disk, which +// is very beneficial for constructing archive nodes. The drawback is it will +// store different trie nodes on the same path to different locations on the disk +// with no data locality, and it's unfriendly for designing state pruning. +// +// Now this scheme is still kept for backward compatibility, and it will be used +// for archive node and some other tries(e.g. light trie). +const HashScheme = "hashScheme" + +// ReadLegacyTrieNode retrieves the legacy trie node with the given +// associated node hash. +func ReadLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash) []byte { + data, err := db.Get(hash.Bytes()) + if err != nil { + return nil + } + return data +} + +// HasLegacyTrieNode checks if the trie node with the provided hash is present in db. +func HasLegacyTrieNode(db ethdb.KeyValueReader, hash common.Hash) bool { + ok, _ := db.Has(hash.Bytes()) + return ok +} + +// WriteLegacyTrieNode writes the provided legacy trie node to database. +func WriteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash, node []byte) { + if err := db.Put(hash.Bytes(), node); err != nil { + log.Crit("Failed to store legacy trie node", "err", err) + } +} + +// DeleteLegacyTrieNode deletes the specified legacy trie node from database. +func DeleteLegacyTrieNode(db ethdb.KeyValueWriter, hash common.Hash) { + if err := db.Delete(hash.Bytes()); err != nil { + log.Crit("Failed to delete legacy trie node", "err", err) + } +} diff --git a/core/rawdb/database.go b/core/rawdb/database.go index 1183a74f51..ea1dfe2347 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -17,10 +17,17 @@ package rawdb import ( + "bytes" "fmt" + "os" + "time" + + "github.com/olekukonko/tablewriter" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/ethdb/leveldb" "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/log" ) // freezerdb is a database wrapper that enabled freezer data retrievals. @@ -108,3 +115,177 @@ func NewLevelDBDatabase(file string, cache int, handles int, namespace string) ( } return NewDatabase(db), nil } + +type counter uint64 + +func (c counter) String() string { + return fmt.Sprintf("%d", c) +} + +func (c counter) Percentage(current uint64) string { + return fmt.Sprintf("%d", current*100/uint64(c)) +} + +// stat stores sizes and count for a parameter +type stat struct { + size common.StorageSize + count counter +} + +// Add size to the stat and increase the counter by 1 +func (s *stat) Add(size common.StorageSize) { + s.size += size + s.count++ +} + +func (s *stat) Size() string { + return s.size.String() +} + +func (s *stat) Count() string { + return s.count.String() +} + +// InspectDatabase traverses the entire database and checks the size +// of all different categories of data. +func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { + it := db.NewIterator(keyPrefix, keyStart) + defer it.Release() + + var ( + count int64 + start = time.Now() + logged = time.Now() + + // Key-value store statistics + headers stat + bodies stat + receipts stat + tds stat + numHashPairings stat + hashNumPairings stat + tries stat + codes stat + txLookups stat + accountSnaps stat + storageSnaps stat + preimages stat + bloomBits stat + cliqueSnaps stat + + // Ancient store statistics + ancientHeadersSize common.StorageSize + ancientBodiesSize common.StorageSize + ancientReceiptsSize common.StorageSize + ancientTdsSize common.StorageSize + ancientHashesSize common.StorageSize + + // Les statistic + chtTrieNodes stat + bloomTrieNodes stat + + // Meta- and unaccounted data + metadata stat + unaccounted stat + + // Totals + total common.StorageSize + ) + // Inspect key-value database first. + for it.Next() { + var ( + key = it.Key() + size = common.StorageSize(len(key) + len(it.Value())) + ) + total += size + switch { + case bytes.HasPrefix(key, headerPrefix) && len(key) == (len(headerPrefix)+8+common.HashLength): + headers.Add(size) + case bytes.HasPrefix(key, blockBodyPrefix) && len(key) == (len(blockBodyPrefix)+8+common.HashLength): + bodies.Add(size) + case bytes.HasPrefix(key, blockReceiptsPrefix) && len(key) == (len(blockReceiptsPrefix)+8+common.HashLength): + receipts.Add(size) + case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerTDSuffix): + tds.Add(size) + case bytes.HasPrefix(key, headerPrefix) && bytes.HasSuffix(key, headerHashSuffix): + numHashPairings.Add(size) + case bytes.HasPrefix(key, headerNumberPrefix) && len(key) == (len(headerNumberPrefix)+common.HashLength): + hashNumPairings.Add(size) + case len(key) == common.HashLength: + tries.Add(size) + case bytes.HasPrefix(key, txLookupPrefix) && len(key) == (len(txLookupPrefix)+common.HashLength): + txLookups.Add(size) + case bytes.HasPrefix(key, SnapshotAccountPrefix) && len(key) == (len(SnapshotAccountPrefix)+common.HashLength): + accountSnaps.Add(size) + case bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == (len(SnapshotStoragePrefix)+2*common.HashLength): + storageSnaps.Add(size) + case bytes.HasPrefix(key, []byte(PreimagePrefix)) && len(key) == (len(PreimagePrefix)+common.HashLength): + preimages.Add(size) + case bytes.HasPrefix(key, bloomBitsPrefix) && len(key) == (len(bloomBitsPrefix)+10+common.HashLength): + bloomBits.Add(size) + case bytes.HasPrefix(key, []byte("clique-")) && len(key) == 7+common.HashLength: + cliqueSnaps.Add(size) + case bytes.HasPrefix(key, []byte("cht-")) && len(key) == 4+common.HashLength: + chtTrieNodes.Add(size) + case bytes.HasPrefix(key, []byte("blt-")) && len(key) == 4+common.HashLength: + bloomTrieNodes.Add(size) + default: + var accounted bool + for _, meta := range [][]byte{databaseVersionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey} { + if bytes.Equal(key, meta) { + metadata.Add(size) + accounted = true + break + } + } + if !accounted { + unaccounted.Add(size) + } + } + count += 1 + if count%1000 == 0 && time.Since(logged) > 8*time.Second { + log.Info("Inspecting database", "count", count, "elapsed", common.PrettyDuration(time.Since(start))) + logged = time.Now() + } + } + // Get number of ancient rows inside the freezer + ancients := counter(0) + if count, err := db.Ancients(); err == nil { + ancients = counter(count) + } + // Display the database statistic. + stats := [][]string{ + {"Key-Value store", "Headers", headers.Size(), headers.Count()}, + {"Key-Value store", "Bodies", bodies.Size(), bodies.Count()}, + {"Key-Value store", "Receipt lists", receipts.Size(), receipts.Count()}, + {"Key-Value store", "Difficulties", tds.Size(), tds.Count()}, + {"Key-Value store", "Block number->hash", numHashPairings.Size(), numHashPairings.Count()}, + {"Key-Value store", "Block hash->number", hashNumPairings.Size(), hashNumPairings.Count()}, + {"Key-Value store", "Transaction index", txLookups.Size(), txLookups.Count()}, + {"Key-Value store", "Bloombit index", bloomBits.Size(), bloomBits.Count()}, + {"Key-Value store", "Contract codes", codes.Size(), codes.Count()}, + {"Key-Value store", "Trie nodes", tries.Size(), tries.Count()}, + {"Key-Value store", "Trie preimages", preimages.Size(), preimages.Count()}, + {"Key-Value store", "Account snapshot", accountSnaps.Size(), accountSnaps.Count()}, + {"Key-Value store", "Storage snapshot", storageSnaps.Size(), storageSnaps.Count()}, + {"Key-Value store", "Clique snapshots", cliqueSnaps.Size(), cliqueSnaps.Count()}, + {"Key-Value store", "Singleton metadata", metadata.Size(), metadata.Count()}, + {"Ancient store", "Headers", ancientHeadersSize.String(), ancients.String()}, + {"Ancient store", "Bodies", ancientBodiesSize.String(), ancients.String()}, + {"Ancient store", "Receipt lists", ancientReceiptsSize.String(), ancients.String()}, + {"Ancient store", "Difficulties", ancientTdsSize.String(), ancients.String()}, + {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()}, + {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()}, + {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()}, + } + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Database", "Category", "Size", "Items"}) + table.SetFooter([]string{"", "Total", total.String(), " "}) + table.AppendBulk(stats) + table.Render() + + if unaccounted.size > 0 { + log.Error("Database contains unaccounted data", "size", unaccounted.size, "count", unaccounted.count) + } + return nil +} diff --git a/core/rawdb/key_length_iterator.go b/core/rawdb/key_length_iterator.go new file mode 100644 index 0000000000..9e24f0ec32 --- /dev/null +++ b/core/rawdb/key_length_iterator.go @@ -0,0 +1,47 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rawdb + +import "github.com/tomochain/tomochain/ethdb" + +// KeyLengthIterator is a wrapper for a database iterator that ensures only key-value pairs +// with a specific key length will be returned. +type KeyLengthIterator struct { + requiredKeyLength int + ethdb.Iterator +} + +// NewKeyLengthIterator returns a wrapped version of the iterator that will only return key-value +// pairs where keys with a specific key length will be returned. +func NewKeyLengthIterator(it ethdb.Iterator, keyLen int) ethdb.Iterator { + return &KeyLengthIterator{ + Iterator: it, + requiredKeyLength: keyLen, + } +} + +func (it *KeyLengthIterator) Next() bool { + // Return true as soon as a key with the required key length is discovered + for it.Iterator.Next() { + if len(it.Iterator.Key()) == it.requiredKeyLength { + return true + } + } + + // Return false when we exhaust the keys in the underlying iterator. + return false +} diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go new file mode 100644 index 0000000000..b49e238ab6 --- /dev/null +++ b/core/rawdb/schema.go @@ -0,0 +1,166 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package rawdb contains a collection of low level database accessors. +package rawdb + +import ( + "encoding/binary" + "errors" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/metrics" +) + +var ( + // databaseVersionKey tracks the current database version. + databaseVersionKey = []byte("DatabaseVersion") + + // headFastBlockKey tracks the latest known incomplete block's hash during fast sync. + headFastBlockKey = []byte("LastFast") + + // fastTrieProgressKey tracks the number of trie entries imported during fast sync. + fastTrieProgressKey = []byte("TrieSync") + + // snapshotRootKey tracks the hash of the last snapshot. + snapshotRootKey = []byte("SnapshotRoot") + + // snapshotJournalKey tracks the in-memory diff layers across restarts. + snapshotJournalKey = []byte("SnapshotJournal") + + // snapshotGeneratorKey tracks the snapshot generation marker across restarts. + snapshotGeneratorKey = []byte("SnapshotGenerator") + + headHeaderKey = []byte("LastHeader") + headBlockKey = []byte("LastBlock") + headFastKey = []byte("LastFast") + trieSyncKey = []byte("TrieSync") + + // Data item prefixes (use single byte to avoid mixing data types, avoid `i`). + headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header + headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td + headerHashSuffix = []byte("n") // headerPrefix + num (uint64 big endian) + headerHashSuffix -> hash + headerNumberPrefix = []byte("H") // headerNumberPrefix + hash -> num (uint64 big endian) + blockBodyPrefix = []byte("b") // blockBodyPrefix + num (uint64 big endian) + hash -> block body + blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts + txLookupPrefix = []byte("l") // txLookupPrefix + hash -> transaction/receipt lookup metadata + bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits + SnapshotAccountPrefix = []byte("a") // SnapshotAccountPrefix + account hash -> account trie value + SnapshotStoragePrefix = []byte("o") // SnapshotStoragePrefix + account hash + storage hash -> storage trie value + + PreimagePrefix = "secure-key-" // PreimagePrefix + hash -> preimage + configPrefix = []byte("ethereum-config-") // config prefix for the db + + // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress + BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress + + // used by old db, now only used for conversion + oldReceiptsPrefix = []byte("receipts-") + oldTxMetaSuffix = []byte{0x01} + + ErrChainConfigNotFound = errors.New("ChainConfig not found") // general config not found error + + preimageCounter = metrics.NewRegisteredCounter("db/preimage/total", nil) + preimageHitCounter = metrics.NewRegisteredCounter("db/preimage/hits", nil) +) + +// TxLookupEntry is a positional metadata to help looking up the data content of +// a transaction or receipt given only its hash. +type TxLookupEntry struct { + BlockHash common.Hash + BlockIndex uint64 + Index uint64 +} + +// configKey = configPrefix + hash +func configKey(hash common.Hash) []byte { + return append(configPrefix, hash.Bytes()...) +} + +// headerKey = headerPrefix + num (uint64 big endian) + hash +func headerKey(number uint64, hash common.Hash) []byte { + return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// headerTDKey = headerPrefix + num (uint64 big endian) + hash + headerTDSuffix +func headerTDKey(number uint64, hash common.Hash) []byte { + return append(HeaderKey(number, hash), headerTDSuffix...) +} + +// headerHashKey = headerPrefix + num (uint64 big endian) + headerHashSuffix +func headerHashKey(number uint64) []byte { + return append(append(headerPrefix, encodeBlockNumber(number)...), headerHashSuffix...) +} + +// HeaderKey = headerPrefix + num (uint64 big endian) + hash +func HeaderKey(number uint64, hash common.Hash) []byte { + return append(append(headerPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// headerNumberKey = headerNumberPrefix + hash +func headerNumberKey(hash common.Hash) []byte { + return append(headerNumberPrefix, hash.Bytes()...) +} + +// BlockBodyKey = blockBodyPrefix + num (uint64 big endian) + hash +func BlockBodyKey(number uint64, hash common.Hash) []byte { + return append(append(blockBodyPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// blockReceiptsKey = blockReceiptsPrefix + num (uint64 big endian) + hash +func blockReceiptsKey(number uint64, hash common.Hash) []byte { + return append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + +// txLookupKey = txLookupPrefix + hash +func txLookupKey(hash common.Hash) []byte { + return append(txLookupPrefix, hash.Bytes()...) +} + +// bloomBitsKey = bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash +func bloomBitsKey(bit uint, section uint64, hash common.Hash) []byte { + key := append(append(bloomBitsPrefix, make([]byte, 10)...), hash.Bytes()...) + + binary.BigEndian.PutUint16(key[1:], uint16(bit)) + binary.BigEndian.PutUint64(key[3:], section) + + return key +} + +// oldTxMetaKey = hash + oldTxMetaSuffix +func oldTxMetaKey(hash common.Hash) []byte { + return append(hash.Bytes(), oldTxMetaSuffix...) +} + +// oldReceiptsKey = oldReceiptsPrefix + hash +func oldReceiptsKey(hash common.Hash) []byte { + return append(oldReceiptsPrefix, hash[:]...) +} + +// accountSnapshotKey = SnapshotAccountPrefix + hash +func accountSnapshotKey(hash common.Hash) []byte { + return append(SnapshotAccountPrefix, hash.Bytes()...) +} + +// storageSnapshotKey = SnapshotStoragePrefix + account hash + storage hash +func storageSnapshotKey(accountHash, storageHash common.Hash) []byte { + return append(append(SnapshotStoragePrefix, accountHash.Bytes()...), storageHash.Bytes()...) +} + +// storageSnapshotsKey = SnapshotStoragePrefix + account hash + storage hash +func storageSnapshotsKey(accountHash common.Hash) []byte { + return append(SnapshotStoragePrefix, accountHash.Bytes()...) +} diff --git a/core/state/database.go b/core/state/database.go index b57f134db8..8f47c33965 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -21,6 +21,7 @@ import ( lru "github.com/hashicorp/golang-lru" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/trie" ) @@ -59,20 +60,40 @@ type Trie interface { // TODO(fjl): remove this when SecureTrie is removed GetKey([]byte) []byte - // TryGet returns the value for key stored in the trie. The value bytes must - // not be modified by the caller. If a node was not found in the database, a - // trie.MissingNodeError is returned. - TryGet(key []byte) ([]byte, error) - - // TryUpdate associates key with value in the trie. If value has length zero, any - // existing value is deleted from the trie. The value bytes must not be modified + // GetStorage returns the value for key stored in the trie. The value bytes + // must not be modified by the caller. If a node was not found in the database, + // a trie.MissingNodeError is returned. + GetStorage(addr common.Address, key []byte) ([]byte, error) + + // GetAccount abstracts an account read from the trie. It retrieves the + // account blob from the trie with provided account address and decodes it + // with associated decoding algorithm. If the specified account is not in + // the trie, nil will be returned. If the trie is corrupted(e.g. some nodes + // are missing or the account blob is incorrect for decoding), an error will + // be returned. + GetAccount(address common.Address) (*types.StateAccount, error) + + // UpdateStorage associates key with value in the trie. If value has length zero, + // any existing value is deleted from the trie. The value bytes must not be modified // by the caller while they are stored in the trie. If a node was not found in the // database, a trie.MissingNodeError is returned. - TryUpdate(key, value []byte) error + UpdateStorage(addr common.Address, key, value []byte) error + + // UpdateAccount abstracts an account write to the trie. It encodes the + // provided account object with associated algorithm and then updates it + // in the trie with provided address. + UpdateAccount(address common.Address, account *types.StateAccount) error + + // UpdateContractCode abstracts code write to the trie. It is expected + // to be moved to the stateWriter interface when the latter is ready. + UpdateContractCode(address common.Address, codeHash common.Hash, code []byte) error + + // DeleteStorage removes any existing value for key from the trie. If a node + // was not found in the database, a trie.MissingNodeError is returned. + DeleteStorage(addr common.Address, key []byte) error - // TryDelete removes any existing value for key from the trie. If a node was not - // found in the database, a trie.MissingNodeError is returned. - TryDelete(key []byte) error + // DeleteAccount abstracts an account deletion from the trie. + DeleteAccount(address common.Address) error // Hash returns the root hash of the trie. It does not write to the database and // can be used even if the trie doesn't have one. @@ -98,18 +119,18 @@ type Trie interface { // NewDatabase creates a backing store for state. The returned database is safe for // concurrent use, but does not retain any recent trie nodes in memory. To keep some -// historical state in memory, use the NewDatabaseWithCache constructor. +// historical state in memory, use the NewDatabaseWithConfig constructor. func NewDatabase(db ethdb.Database) Database { - return NewDatabaseWithCache(db, 0) + return NewDatabaseWithConfig(db, nil) } -// NewDatabaseWithCache creates a backing store for state. The returned database +// NewDatabaseWithConfig creates a backing store for state. The returned database // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // large memory cache. -func NewDatabaseWithCache(db ethdb.Database, cache int) Database { +func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithCache(db, cache), + db: trie.NewDatabaseWithConfig(db, config), codeSizeCache: csc, } } diff --git a/core/state/dump.go b/core/state/dump.go index f08c6e7df3..6d8994462f 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -39,40 +40,40 @@ type Dump struct { Accounts map[string]DumpAccount `json:"accounts"` } -func (self *StateDB) RawDump() Dump { +func (s *StateDB) RawDump() Dump { dump := Dump{ - Root: fmt.Sprintf("%x", self.trie.Hash()), + Root: fmt.Sprintf("%x", s.trie.Hash()), Accounts: make(map[string]DumpAccount), } - it := trie.NewIterator(self.trie.NodeIterator(nil)) + it := trie.NewIterator(s.trie.NodeIterator(nil)) for it.Next() { - addr := self.trie.GetKey(it.Key) - var data Account + addr := s.trie.GetKey(it.Key) + var data types.StateAccount if err := rlp.DecodeBytes(it.Value, &data); err != nil { panic(err) } - obj := newObject(nil, common.BytesToAddress(addr), data, nil) + obj := newObject(nil, common.BytesToAddress(addr), &data) account := DumpAccount{ Balance: data.Balance.String(), Nonce: data.Nonce, Root: common.Bytes2Hex(data.Root[:]), CodeHash: common.Bytes2Hex(data.CodeHash), - Code: common.Bytes2Hex(obj.Code(self.db)), + Code: common.Bytes2Hex(obj.Code(s.db)), Storage: make(map[string]string), } - storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator(nil)) + storageIt := trie.NewIterator(obj.getTrie(s.db).NodeIterator(nil)) for storageIt.Next() { - account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) + account.Storage[common.Bytes2Hex(s.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) } dump.Accounts[common.Bytes2Hex(addr)] = account } return dump } -func (self *StateDB) Dump() []byte { - json, err := json.MarshalIndent(self.RawDump(), "", " ") +func (s *StateDB) Dump() []byte { + json, err := json.MarshalIndent(s.RawDump(), "", " ") if err != nil { fmt.Println("dump err", err) } diff --git a/core/state/iterator.go b/core/state/iterator.go index 3cfc592ecb..d69321f36a 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -104,7 +105,7 @@ func (it *NodeIterator) step() error { return nil } // Otherwise we've reached an account node, initiate data iteration - var account Account + var account types.StateAccount if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go index 20864e0768..b5b7287025 100644 --- a/core/state/iterator_test.go +++ b/core/state/iterator_test.go @@ -29,7 +29,7 @@ func TestNodeIteratorCoverage(t *testing.T) { // Create some arbitrary test state to iterate db, root, _ := makeTestState() - state, err := New(root, db) + state, err := New(root, db, nil) if err != nil { t.Fatalf("failed to create state trie at %x: %v", root, err) } diff --git a/core/state/journal.go b/core/state/journal.go index 1ac5cdbf25..2c75c9dbef 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -22,11 +22,67 @@ import ( "github.com/tomochain/tomochain/common" ) +// journalEntry is a modification entry in the state change journal that can be +// reverted on demand. type journalEntry interface { - undo(*StateDB) + // revert undoes the changes introduced by this journal entry. + revert(*StateDB) + + // dirtied returns the Ethereum address modified by this journal entry. + dirtied() *common.Address +} + +// journal contains the list of state modifications applied since the last state +// commit. These are tracked to be able to be reverted in case of an execution +// exception or revertal request. +type journal struct { + entries []journalEntry // Current changes tracked by the journal + dirties map[common.Address]int // Dirty accounts and the number of changes } -type journal []journalEntry +// newJournal create a new initialized journal. +func newJournal() *journal { + return &journal{ + dirties: make(map[common.Address]int), + } +} + +// append inserts a new modification entry to the end of the change journal. +func (j *journal) append(entry journalEntry) { + j.entries = append(j.entries, entry) + if addr := entry.dirtied(); addr != nil { + j.dirties[*addr]++ + } +} + +// revert undoes a batch of journalled modifications along with any reverted +// dirty handling too. +func (j *journal) revert(statedb *StateDB, snapshot int) { + for i := len(j.entries) - 1; i >= snapshot; i-- { + // Undo the changes made by the operation + j.entries[i].revert(statedb) + + // Drop any dirty tracking induced by the change + if addr := j.entries[i].dirtied(); addr != nil { + if j.dirties[*addr]--; j.dirties[*addr] == 0 { + delete(j.dirties, *addr) + } + } + } + j.entries = j.entries[:snapshot] +} + +// dirty explicitly sets an address to dirty, even if the change entries would +// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD +// precompile consensus exception. +func (j *journal) dirty(addr common.Address) { + j.dirties[addr]++ +} + +// length returns the current number of entries in the journal. +func (j *journal) length() int { + return len(j.entries) +} type ( // Changes to the account trie. @@ -34,7 +90,8 @@ type ( account *common.Address } resetObjectChange struct { - prev *stateObject + prev *stateObject + prevdestruct bool } suicideChange struct { account *common.Address @@ -77,16 +134,27 @@ type ( } ) -func (ch createObjectChange) undo(s *StateDB) { +func (ch createObjectChange) revert(s *StateDB) { delete(s.stateObjects, *ch.account) delete(s.stateObjectsDirty, *ch.account) } -func (ch resetObjectChange) undo(s *StateDB) { +func (ch createObjectChange) dirtied() *common.Address { + return ch.account +} + +func (ch resetObjectChange) revert(s *StateDB) { s.setStateObject(ch.prev) + if !ch.prevdestruct && s.snap != nil { + delete(s.snapDestructs, ch.prev.addrHash) + } +} + +func (ch resetObjectChange) dirtied() *common.Address { + return nil } -func (ch suicideChange) undo(s *StateDB) { +func (ch suicideChange) revert(s *StateDB) { obj := s.getStateObject(*ch.account) if obj != nil { obj.suicided = ch.prev @@ -94,38 +162,60 @@ func (ch suicideChange) undo(s *StateDB) { } } +func (ch suicideChange) dirtied() *common.Address { + return ch.account +} + var ripemd = common.HexToAddress("0000000000000000000000000000000000000003") -func (ch touchChange) undo(s *StateDB) { - if !ch.prev && *ch.account != ripemd { - s.getStateObject(*ch.account).touched = ch.prev - if !ch.prevDirty { - delete(s.stateObjectsDirty, *ch.account) - } - } +func (ch touchChange) revert(s *StateDB) { +} + +func (ch touchChange) dirtied() *common.Address { + return ch.account } -func (ch balanceChange) undo(s *StateDB) { +func (ch balanceChange) revert(s *StateDB) { s.getStateObject(*ch.account).setBalance(ch.prev) } -func (ch nonceChange) undo(s *StateDB) { +func (ch balanceChange) dirtied() *common.Address { + return ch.account +} + +func (ch nonceChange) revert(s *StateDB) { s.getStateObject(*ch.account).setNonce(ch.prev) } -func (ch codeChange) undo(s *StateDB) { +func (ch nonceChange) dirtied() *common.Address { + return ch.account +} + +func (ch codeChange) revert(s *StateDB) { s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) } -func (ch storageChange) undo(s *StateDB) { +func (ch codeChange) dirtied() *common.Address { + return ch.account +} + +func (ch storageChange) revert(s *StateDB) { s.getStateObject(*ch.account).setState(ch.key, ch.prevalue) } -func (ch refundChange) undo(s *StateDB) { +func (ch storageChange) dirtied() *common.Address { + return ch.account +} + +func (ch refundChange) revert(s *StateDB) { s.refund = ch.prev } -func (ch addLogChange) undo(s *StateDB) { +func (ch refundChange) dirtied() *common.Address { + return nil +} + +func (ch addLogChange) revert(s *StateDB) { logs := s.logs[ch.txhash] if len(logs) == 1 { delete(s.logs, ch.txhash) @@ -135,6 +225,14 @@ func (ch addLogChange) undo(s *StateDB) { s.logSize-- } -func (ch addPreimageChange) undo(s *StateDB) { +func (ch addLogChange) dirtied() *common.Address { + return nil +} + +func (ch addPreimageChange) revert(s *StateDB) { delete(s.preimages, ch.hash) } + +func (ch addPreimageChange) dirtied() *common.Address { + return nil +} diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index 79220dc077..c4fa4937aa 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -17,17 +17,17 @@ package state import ( - "github.com/tomochain/tomochain/core/rawdb" "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" ) var addr = common.BytesToAddress([]byte("test")) func create() (*ManagedState, *account) { db := rawdb.NewMemoryDatabase() - statedb, _ := New(common.Hash{}, NewDatabase(db)) + statedb, _ := New(common.Hash{}, NewDatabase(db), nil) ms := ManageState(statedb) ms.StateDB.SetNonce(addr, 100) ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr)) diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go new file mode 100644 index 0000000000..98214497c8 --- /dev/null +++ b/core/state/snapshot/difflayer.go @@ -0,0 +1,535 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "encoding/binary" + "fmt" + "math" + "math/rand" + "sort" + "sync" + "sync/atomic" + "time" + + bloomfilter "github.com/holiman/bloomfilter/v2" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/rlp" +) + +var ( + // aggregatorMemoryLimit is the maximum size of the bottom-most diff layer + // that aggregates the writes from above until it's flushed into the disk + // layer. + // + // Note, bumping this up might drastically increase the size of the bloom + // filters that's stored in every diff layer. Don't do that without fully + // understanding all the implications. + aggregatorMemoryLimit = uint64(4 * 1024 * 1024) + + // aggregatorItemLimit is an approximate number of items that will end up + // in the agregator layer before it's flushed out to disk. A plain account + // weighs around 14B (+hash), a storage slot 32B (+hash), a deleted slot + // 0B (+hash). Slots are mostly set/unset in lockstep, so thet average at + // 16B (+hash). All in all, the average entry seems to be 15+32=47B. Use a + // smaller number to be on the safe side. + aggregatorItemLimit = aggregatorMemoryLimit / 42 + + // bloomTargetError is the target false positive rate when the aggregator + // layer is at its fullest. The actual value will probably move around up + // and down from this number, it's mostly a ballpark figure. + // + // Note, dropping this down might drastically increase the size of the bloom + // filters that's stored in every diff layer. Don't do that without fully + // understanding all the implications. + bloomTargetError = 0.02 + + // bloomSize is the ideal bloom filter size given the maximum number of items + // it's expected to hold and the target false positive error rate. + bloomSize = math.Ceil(float64(aggregatorItemLimit) * math.Log(bloomTargetError) / math.Log(1/math.Pow(2, math.Log(2)))) + + // bloomFuncs is the ideal number of bits a single entry should set in the + // bloom filter to keep its size to a minimum (given it's size and maximum + // entry count). + bloomFuncs = math.Round((bloomSize / float64(aggregatorItemLimit)) * math.Log(2)) + + // the bloom offsets are runtime constants which determines which part of the + // the account/storage hash the hasher functions looks at, to determine the + // bloom key for an account/slot. This is randomized at init(), so that the + // global population of nodes do not all display the exact same behaviour with + // regards to bloom content + bloomDestructHasherOffset = 0 + bloomAccountHasherOffset = 0 + bloomStorageHasherOffset = 0 +) + +func init() { + // Init the bloom offsets in the range [0:24] (requires 8 bytes) + bloomDestructHasherOffset = rand.Intn(25) + bloomAccountHasherOffset = rand.Intn(25) + bloomStorageHasherOffset = rand.Intn(25) + + // The destruct and account blooms must be different, as the storage slots + // will check for destruction too for every bloom miss. It should not collide + // with modified accounts. + for bloomAccountHasherOffset == bloomDestructHasherOffset { + bloomAccountHasherOffset = rand.Intn(25) + } +} + +// diffLayer represents a collection of modifications made to a state snapshot +// after running a block on top. It contains one sorted list for the account trie +// and one-one list for each storage tries. +// +// The goal of a diff layer is to act as a journal, tracking recent modifications +// made to the state, that have not yet graduated into a semi-immutable state. +type diffLayer struct { + origin *diskLayer // Base disk layer to directly use on bloom misses + parent snapshot // Parent snapshot modified by this one, never nil + memory uint64 // Approximate guess as to how much memory we use + + root common.Hash // Root hash to which this snapshot diff belongs to + stale uint32 // Signals that the layer became stale (state progressed) + + destructSet map[common.Hash]struct{} // Keyed markers for deleted (and potentially) recreated accounts + accountList []common.Hash // List of account for iteration. If it exists, it's sorted, otherwise it's nil + accountData map[common.Hash][]byte // Keyed accounts for direct retrival (nil means deleted) + storageList map[common.Hash][]common.Hash // List of storage slots for iterated retrievals, one per account. Any existing lists are sorted if non-nil + storageData map[common.Hash]map[common.Hash][]byte // Keyed storage slots for direct retrival. one per account (nil means deleted) + + diffed *bloomfilter.Filter // Bloom filter tracking all the diffed items up to the disk layer + + lock sync.RWMutex +} + +// destructBloomHasher is a wrapper around a common.Hash to satisfy the interface +// API requirements of the bloom library used. It's used to convert a destruct +// event into a 64 bit mini hash. +type destructBloomHasher common.Hash + +func (h destructBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") } +func (h destructBloomHasher) Sum(b []byte) []byte { panic("not implemented") } +func (h destructBloomHasher) Reset() { panic("not implemented") } +func (h destructBloomHasher) BlockSize() int { panic("not implemented") } +func (h destructBloomHasher) Size() int { return 8 } +func (h destructBloomHasher) Sum64() uint64 { + return binary.BigEndian.Uint64(h[bloomDestructHasherOffset : bloomDestructHasherOffset+8]) +} + +// accountBloomHasher is a wrapper around a common.Hash to satisfy the interface +// API requirements of the bloom library used. It's used to convert an account +// hash into a 64 bit mini hash. +type accountBloomHasher common.Hash + +func (h accountBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") } +func (h accountBloomHasher) Sum(b []byte) []byte { panic("not implemented") } +func (h accountBloomHasher) Reset() { panic("not implemented") } +func (h accountBloomHasher) BlockSize() int { panic("not implemented") } +func (h accountBloomHasher) Size() int { return 8 } +func (h accountBloomHasher) Sum64() uint64 { + return binary.BigEndian.Uint64(h[bloomAccountHasherOffset : bloomAccountHasherOffset+8]) +} + +// storageBloomHasher is a wrapper around a [2]common.Hash to satisfy the interface +// API requirements of the bloom library used. It's used to convert an account +// hash into a 64 bit mini hash. +type storageBloomHasher [2]common.Hash + +func (h storageBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") } +func (h storageBloomHasher) Sum(b []byte) []byte { panic("not implemented") } +func (h storageBloomHasher) Reset() { panic("not implemented") } +func (h storageBloomHasher) BlockSize() int { panic("not implemented") } +func (h storageBloomHasher) Size() int { return 8 } +func (h storageBloomHasher) Sum64() uint64 { + return binary.BigEndian.Uint64(h[0][bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^ + binary.BigEndian.Uint64(h[1][bloomStorageHasherOffset:bloomStorageHasherOffset+8]) +} + +// newDiffLayer creates a new diff on top of an existing snapshot, whether that's a low +// level persistent database or a hierarchical diff already. +func newDiffLayer(parent snapshot, root common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer { + // Create the new layer with some pre-allocated data segments + dl := &diffLayer{ + parent: parent, + root: root, + destructSet: destructs, + accountData: accounts, + storageData: storage, + } + switch parent := parent.(type) { + case *diskLayer: + dl.rebloom(parent) + case *diffLayer: + dl.rebloom(parent.origin) + default: + panic("unknown parent type") + } + // Sanity check that accounts or storage slots are never nil + for accountHash, blob := range accounts { + if blob == nil { + panic(fmt.Sprintf("account %#x nil", accountHash)) + } + } + for accountHash, slots := range storage { + if slots == nil { + panic(fmt.Sprintf("storage %#x nil", accountHash)) + } + } + // Determine memory size and track the dirty writes + for _, data := range accounts { + dl.memory += uint64(common.HashLength + len(data)) + snapshotDirtyAccountWriteMeter.Mark(int64(len(data))) + } + // Fill the storage hashes and sort them for the iterator + dl.storageList = make(map[common.Hash][]common.Hash) + for accountHash := range destructs { + dl.storageList[accountHash] = nil + } + // Determine memory size and track the dirty writes + for _, slots := range storage { + for _, data := range slots { + dl.memory += uint64(common.HashLength + len(data)) + snapshotDirtyStorageWriteMeter.Mark(int64(len(data))) + } + } + dl.memory += uint64(len(dl.storageList) * common.HashLength) + return dl +} + +// rebloom discards the layer's current bloom and rebuilds it from scratch based +// on the parent's and the local diffs. +func (dl *diffLayer) rebloom(origin *diskLayer) { + dl.lock.Lock() + defer dl.lock.Unlock() + + defer func(start time.Time) { + snapshotBloomIndexTimer.Update(time.Since(start)) + }(time.Now()) + + // Inject the new origin that triggered the rebloom + dl.origin = origin + + // Retrieve the parent bloom or create a fresh empty one + if parent, ok := dl.parent.(*diffLayer); ok { + parent.lock.RLock() + dl.diffed, _ = parent.diffed.Copy() + parent.lock.RUnlock() + } else { + dl.diffed, _ = bloomfilter.New(uint64(bloomSize), uint64(bloomFuncs)) + } + // Iterate over all the accounts and storage slots and index them + for hash := range dl.destructSet { + dl.diffed.Add(destructBloomHasher(hash)) + } + for hash := range dl.accountData { + dl.diffed.Add(accountBloomHasher(hash)) + } + for accountHash, slots := range dl.storageData { + for storageHash := range slots { + dl.diffed.Add(storageBloomHasher{accountHash, storageHash}) + } + } + // Calculate the current false positive rate and update the error rate meter. + // This is a bit cheating because subsequent layers will overwrite it, but it + // should be fine, we're only interested in ballpark figures. + k := float64(dl.diffed.K()) + n := float64(dl.diffed.N()) + m := float64(dl.diffed.M()) + snapshotBloomErrorGauge.Update(math.Pow(1.0-math.Exp((-k)*(n+0.5)/(m-1)), k)) +} + +// Root returns the root hash for which this snapshot was made. +func (dl *diffLayer) Root() common.Hash { + return dl.root +} + +// Parent returns the subsequent layer of a diff layer. +func (dl *diffLayer) Parent() snapshot { + return dl.parent +} + +// Stale return whether this layer has become stale (was flattened across) or if +// it's still live. +func (dl *diffLayer) Stale() bool { + return atomic.LoadUint32(&dl.stale) != 0 +} + +// Account directly retrieves the account associated with a particular hash in +// the snapshot slim data format. +func (dl *diffLayer) Account(hash common.Hash) (*types.SlimAccount, error) { + data, err := dl.AccountRLP(hash) + if err != nil { + return nil, err + } + if len(data) == 0 { // can be both nil and []byte{} + return nil, nil + } + account := new(types.SlimAccount) + if err := rlp.DecodeBytes(data, account); err != nil { + panic(err) + } + return account, nil +} + +// AccountRLP directly retrieves the account RLP associated with a particular +// hash in the snapshot slim data format. +func (dl *diffLayer) AccountRLP(hash common.Hash) ([]byte, error) { + // Check the bloom filter first whether there's even a point in reaching into + // all the maps in all the layers below + dl.lock.RLock() + hit := dl.diffed.Contains(accountBloomHasher(hash)) + if !hit { + hit = dl.diffed.Contains(destructBloomHasher(hash)) + } + dl.lock.RUnlock() + + // If the bloom filter misses, don't even bother with traversing the memory + // diff layers, reach straight into the bottom persistent disk layer + if !hit { + snapshotBloomAccountMissMeter.Mark(1) + return dl.origin.AccountRLP(hash) + } + // The bloom filter hit, start poking in the internal maps + return dl.accountRLP(hash, 0) +} + +// accountRLP is an internal version of AccountRLP that skips the bloom filter +// checks and uses the internal maps to try and retrieve the data. It's meant +// to be used if a higher layer's bloom filter hit already. +func (dl *diffLayer) accountRLP(hash common.Hash, depth int) ([]byte, error) { + dl.lock.RLock() + defer dl.lock.RUnlock() + + // If the layer was flattened into, consider it invalid (any live reference to + // the original should be marked as unusable). + if dl.Stale() { + return nil, ErrSnapshotStale + } + // If the account is known locally, return it + if data, ok := dl.accountData[hash]; ok { + snapshotDirtyAccountHitMeter.Mark(1) + snapshotDirtyAccountHitDepthHist.Update(int64(depth)) + snapshotDirtyAccountReadMeter.Mark(int64(len(data))) + snapshotBloomAccountTrueHitMeter.Mark(1) + return data, nil + } + // If the account is known locally, but deleted, return it + if _, ok := dl.destructSet[hash]; ok { + snapshotDirtyAccountHitMeter.Mark(1) + snapshotDirtyAccountHitDepthHist.Update(int64(depth)) + snapshotDirtyAccountInexMeter.Mark(1) + snapshotBloomAccountTrueHitMeter.Mark(1) + return nil, nil + } + // Account unknown to this diff, resolve from parent + if diff, ok := dl.parent.(*diffLayer); ok { + return diff.accountRLP(hash, depth+1) + } + // Failed to resolve through diff layers, mark a bloom error and use the disk + snapshotBloomAccountFalseHitMeter.Mark(1) + return dl.parent.AccountRLP(hash) +} + +// Storage directly retrieves the storage data associated with a particular hash, +// within a particular account. If the slot is unknown to this diff, it's parent +// is consulted. +func (dl *diffLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) { + // Check the bloom filter first whether there's even a point in reaching into + // all the maps in all the layers below + dl.lock.RLock() + hit := dl.diffed.Contains(storageBloomHasher{accountHash, storageHash}) + if !hit { + hit = dl.diffed.Contains(destructBloomHasher(accountHash)) + } + dl.lock.RUnlock() + + // If the bloom filter misses, don't even bother with traversing the memory + // diff layers, reach straight into the bottom persistent disk layer + if !hit { + snapshotBloomStorageMissMeter.Mark(1) + return dl.origin.Storage(accountHash, storageHash) + } + // The bloom filter hit, start poking in the internal maps + return dl.storage(accountHash, storageHash, 0) +} + +// storage is an internal version of Storage that skips the bloom filter checks +// and uses the internal maps to try and retrieve the data. It's meant to be +// used if a higher layer's bloom filter hit already. +func (dl *diffLayer) storage(accountHash, storageHash common.Hash, depth int) ([]byte, error) { + dl.lock.RLock() + defer dl.lock.RUnlock() + + // If the layer was flattened into, consider it invalid (any live reference to + // the original should be marked as unusable). + if dl.Stale() { + return nil, ErrSnapshotStale + } + // If the account is known locally, try to resolve the slot locally + if storage, ok := dl.storageData[accountHash]; ok { + if data, ok := storage[storageHash]; ok { + snapshotDirtyStorageHitMeter.Mark(1) + snapshotDirtyStorageHitDepthHist.Update(int64(depth)) + if n := len(data); n > 0 { + snapshotDirtyStorageReadMeter.Mark(int64(n)) + } else { + snapshotDirtyStorageInexMeter.Mark(1) + } + snapshotBloomStorageTrueHitMeter.Mark(1) + return data, nil + } + } + // If the account is known locally, but deleted, return an empty slot + if _, ok := dl.destructSet[accountHash]; ok { + snapshotDirtyStorageHitMeter.Mark(1) + snapshotDirtyStorageHitDepthHist.Update(int64(depth)) + snapshotDirtyStorageInexMeter.Mark(1) + snapshotBloomStorageTrueHitMeter.Mark(1) + return nil, nil + } + // Storage slot unknown to this diff, resolve from parent + if diff, ok := dl.parent.(*diffLayer); ok { + return diff.storage(accountHash, storageHash, depth+1) + } + // Failed to resolve through diff layers, mark a bloom error and use the disk + snapshotBloomStorageFalseHitMeter.Mark(1) + return dl.parent.Storage(accountHash, storageHash) +} + +// Update creates a new layer on top of the existing snapshot diff tree with +// the specified data items. +func (dl *diffLayer) Update(blockRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer { + return newDiffLayer(dl, blockRoot, destructs, accounts, storage) +} + +// flatten pushes all data from this point downwards, flattening everything into +// a single diff at the bottom. Since usually the lowermost diff is the largest, +// the flattening bulds up from there in reverse. +func (dl *diffLayer) flatten() snapshot { + // If the parent is not diff, we're the first in line, return unmodified + parent, ok := dl.parent.(*diffLayer) + if !ok { + return dl + } + // Parent is a diff, flatten it first (note, apart from weird corned cases, + // flatten will realistically only ever merge 1 layer, so there's no need to + // be smarter about grouping flattens together). + parent = parent.flatten().(*diffLayer) + + parent.lock.Lock() + defer parent.lock.Unlock() + + // Before actually writing all our data to the parent, first ensure that the + // parent hasn't been 'corrupted' by someone else already flattening into it + if atomic.SwapUint32(&parent.stale, 1) != 0 { + panic("parent diff layer is stale") // we've flattened into the same parent from two children, boo + } + // Overwrite all the updated accounts blindly, merge the sorted list + for hash := range dl.destructSet { + parent.destructSet[hash] = struct{}{} + delete(parent.accountData, hash) + delete(parent.storageData, hash) + } + for hash, data := range dl.accountData { + parent.accountData[hash] = data + } + // Overwrite all the updated storage slots (individually) + for accountHash, storage := range dl.storageData { + // If storage didn't exist (or was deleted) in the parent, overwrite blindly + if _, ok := parent.storageData[accountHash]; !ok { + parent.storageData[accountHash] = storage + continue + } + // Storage exists in both parent and child, merge the slots + comboData := parent.storageData[accountHash] + for storageHash, data := range storage { + comboData[storageHash] = data + } + parent.storageData[accountHash] = comboData + } + // Return the combo parent + return &diffLayer{ + parent: parent.parent, + origin: parent.origin, + root: dl.root, + destructSet: parent.destructSet, + accountData: parent.accountData, + storageData: parent.storageData, + storageList: make(map[common.Hash][]common.Hash), + diffed: dl.diffed, + memory: parent.memory + dl.memory, + } +} + +// AccountList returns a sorted list of all accounts in this difflayer, including +// the deleted ones. +// +// Note, the returned slice is not a copy, so do not modify it. +func (dl *diffLayer) AccountList() []common.Hash { + // If an old list already exists, return it + dl.lock.RLock() + list := dl.accountList + dl.lock.RUnlock() + + if list != nil { + return list + } + // No old sorted account list exists, generate a new one + dl.lock.Lock() + defer dl.lock.Unlock() + + dl.accountList = make([]common.Hash, 0, len(dl.destructSet)+len(dl.accountData)) + for hash := range dl.accountData { + dl.accountList = append(dl.accountList, hash) + } + for hash := range dl.destructSet { + if _, ok := dl.accountData[hash]; !ok { + dl.accountList = append(dl.accountList, hash) + } + } + sort.Sort(hashes(dl.accountList)) + return dl.accountList +} + +// StorageList returns a sorted list of all storage slot hashes in this difflayer +// for the given account. +// +// Note, the returned slice is not a copy, so do not modify it. +func (dl *diffLayer) StorageList(accountHash common.Hash) []common.Hash { + // If an old list already exists, return it + dl.lock.RLock() + list := dl.storageList[accountHash] + dl.lock.RUnlock() + + if list != nil { + return list + } + // No old sorted account list exists, generate a new one + dl.lock.Lock() + defer dl.lock.Unlock() + + storageMap := dl.storageData[accountHash] + storageList := make([]common.Hash, 0, len(storageMap)) + for k := range storageMap { + storageList = append(storageList, k) + } + sort.Sort(hashes(storageList)) + dl.storageList[accountHash] = storageList + return storageList +} diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go new file mode 100644 index 0000000000..89814432bb --- /dev/null +++ b/core/state/snapshot/difflayer_test.go @@ -0,0 +1,399 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "math/rand" + "testing" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/ethdb/memorydb" +) + +func copyDestructs(destructs map[common.Hash]struct{}) map[common.Hash]struct{} { + copy := make(map[common.Hash]struct{}) + for hash := range destructs { + copy[hash] = struct{}{} + } + return copy +} + +func copyAccounts(accounts map[common.Hash][]byte) map[common.Hash][]byte { + copy := make(map[common.Hash][]byte) + for hash, blob := range accounts { + copy[hash] = blob + } + return copy +} + +func copyStorage(storage map[common.Hash]map[common.Hash][]byte) map[common.Hash]map[common.Hash][]byte { + copy := make(map[common.Hash]map[common.Hash][]byte) + for accHash, slots := range storage { + copy[accHash] = make(map[common.Hash][]byte) + for slotHash, blob := range slots { + copy[accHash][slotHash] = blob + } + } + return copy +} + +// TestMergeBasics tests some simple merges +func TestMergeBasics(t *testing.T) { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + // Fill up a parent + for i := 0; i < 100; i++ { + h := randomHash() + data := randomAccount() + + accounts[h] = data + if rand.Intn(4) == 0 { + destructs[h] = struct{}{} + } + if rand.Intn(2) == 0 { + accStorage := make(map[common.Hash][]byte) + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + storage[h] = accStorage + } + } + // Add some (identical) layers on top + parent := newDiffLayer(emptyLayer(), common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child := newDiffLayer(parent, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + child = newDiffLayer(child, common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + // And flatten + merged := (child.flatten()).(*diffLayer) + + { // Check account lists + if have, want := len(merged.accountList), 0; have != want { + t.Errorf("accountList wrong: have %v, want %v", have, want) + } + if have, want := len(merged.AccountList()), len(accounts); have != want { + t.Errorf("AccountList() wrong: have %v, want %v", have, want) + } + if have, want := len(merged.accountList), len(accounts); have != want { + t.Errorf("accountList [2] wrong: have %v, want %v", have, want) + } + } + { // Check account drops + if have, want := len(merged.destructSet), len(destructs); have != want { + t.Errorf("accountDrop wrong: have %v, want %v", have, want) + } + } + { // Check storage lists + i := 0 + for aHash, sMap := range storage { + if have, want := len(merged.storageList), i; have != want { + t.Errorf("[1] storageList wrong: have %v, want %v", have, want) + } + if have, want := len(merged.StorageList(aHash)), len(sMap); have != want { + t.Errorf("[2] StorageList() wrong: have %v, want %v", have, want) + } + if have, want := len(merged.storageList[aHash]), len(sMap); have != want { + t.Errorf("storageList wrong: have %v, want %v", have, want) + } + i++ + } + } +} + +// TestMergeDelete tests some deletion +func TestMergeDelete(t *testing.T) { + var ( + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + // Fill up a parent + h1 := common.HexToHash("0x01") + h2 := common.HexToHash("0x02") + + flipDrops := func() map[common.Hash]struct{} { + return map[common.Hash]struct{}{ + h2: struct{}{}, + } + } + flipAccs := func() map[common.Hash][]byte { + return map[common.Hash][]byte{ + h1: randomAccount(), + } + } + flopDrops := func() map[common.Hash]struct{} { + return map[common.Hash]struct{}{ + h1: struct{}{}, + } + } + flopAccs := func() map[common.Hash][]byte { + return map[common.Hash][]byte{ + h2: randomAccount(), + } + } + // Add some flipAccs-flopping layers on top + parent := newDiffLayer(emptyLayer(), common.Hash{}, flipDrops(), flipAccs(), storage) + child := parent.Update(common.Hash{}, flopDrops(), flopAccs(), storage) + child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage) + child = child.Update(common.Hash{}, flopDrops(), flopAccs(), storage) + child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage) + child = child.Update(common.Hash{}, flopDrops(), flopAccs(), storage) + child = child.Update(common.Hash{}, flipDrops(), flipAccs(), storage) + + if data, _ := child.Account(h1); data == nil { + t.Errorf("last diff layer: expected %x account to be non-nil", h1) + } + if data, _ := child.Account(h2); data != nil { + t.Errorf("last diff layer: expected %x account to be nil", h2) + } + if _, ok := child.destructSet[h1]; ok { + t.Errorf("last diff layer: expected %x drop to be missing", h1) + } + if _, ok := child.destructSet[h2]; !ok { + t.Errorf("last diff layer: expected %x drop to be present", h1) + } + // And flatten + merged := (child.flatten()).(*diffLayer) + + if data, _ := merged.Account(h1); data == nil { + t.Errorf("merged layer: expected %x account to be non-nil", h1) + } + if data, _ := merged.Account(h2); data != nil { + t.Errorf("merged layer: expected %x account to be nil", h2) + } + if _, ok := merged.destructSet[h1]; !ok { // Note, drops stay alive until persisted to disk! + t.Errorf("merged diff layer: expected %x drop to be present", h1) + } + if _, ok := merged.destructSet[h2]; !ok { // Note, drops stay alive until persisted to disk! + t.Errorf("merged diff layer: expected %x drop to be present", h1) + } + // If we add more granular metering of memory, we can enable this again, + // but it's not implemented for now + //if have, want := merged.memory, child.memory; have != want { + // t.Errorf("mem wrong: have %d, want %d", have, want) + //} +} + +// This tests that if we create a new account, and set a slot, and then merge +// it, the lists will be correct. +func TestInsertAndMerge(t *testing.T) { + // Fill up a parent + var ( + acc = common.HexToHash("0x01") + slot = common.HexToHash("0x02") + parent *diffLayer + child *diffLayer + ) + { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + parent = newDiffLayer(emptyLayer(), common.Hash{}, destructs, accounts, storage) + } + { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + accounts[acc] = randomAccount() + storage[acc] = make(map[common.Hash][]byte) + storage[acc][slot] = []byte{0x01} + child = newDiffLayer(parent, common.Hash{}, destructs, accounts, storage) + } + // And flatten + merged := (child.flatten()).(*diffLayer) + { // Check that slot value is present + have, _ := merged.Storage(acc, slot) + if want := []byte{0x01}; !bytes.Equal(have, want) { + t.Errorf("merged slot value wrong: have %x, want %x", have, want) + } + } +} + +func emptyLayer() *diskLayer { + return &diskLayer{ + diskdb: memorydb.New(), + cache: fastcache.New(500 * 1024), + } +} + +// BenchmarkSearch checks how long it takes to find a non-existing key +// BenchmarkSearch-6 200000 10481 ns/op (1K per layer) +// BenchmarkSearch-6 200000 10760 ns/op (10K per layer) +// BenchmarkSearch-6 100000 17866 ns/op +// +// BenchmarkSearch-6 500000 3723 ns/op (10k per layer, only top-level RLock() +func BenchmarkSearch(b *testing.B) { + // First, we set up 128 diff layers, with 1K items each + fill := func(parent snapshot) *diffLayer { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + for i := 0; i < 10000; i++ { + accounts[randomHash()] = randomAccount() + } + return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage) + } + var layer snapshot + layer = emptyLayer() + for i := 0; i < 128; i++ { + layer = fill(layer) + } + key := crypto.Keccak256Hash([]byte{0x13, 0x38}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + layer.AccountRLP(key) + } +} + +// BenchmarkSearchSlot checks how long it takes to find a non-existing key +// - Number of layers: 128 +// - Each layers contains the account, with a couple of storage slots +// BenchmarkSearchSlot-6 100000 14554 ns/op +// BenchmarkSearchSlot-6 100000 22254 ns/op (when checking parent root using mutex) +// BenchmarkSearchSlot-6 100000 14551 ns/op (when checking parent number using atomic) +// With bloom filter: +// BenchmarkSearchSlot-6 3467835 351 ns/op +func BenchmarkSearchSlot(b *testing.B) { + // First, we set up 128 diff layers, with 1K items each + accountKey := crypto.Keccak256Hash([]byte{0x13, 0x37}) + storageKey := crypto.Keccak256Hash([]byte{0x13, 0x37}) + accountRLP := randomAccount() + fill := func(parent snapshot) *diffLayer { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + accounts[accountKey] = accountRLP + + accStorage := make(map[common.Hash][]byte) + for i := 0; i < 5; i++ { + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + storage[accountKey] = accStorage + } + return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage) + } + var layer snapshot + layer = emptyLayer() + for i := 0; i < 128; i++ { + layer = fill(layer) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + layer.Storage(accountKey, storageKey) + } +} + +// With accountList and sorting +// BenchmarkFlatten-6 50 29890856 ns/op +// +// Without sorting and tracking accountlist +// BenchmarkFlatten-6 300 5511511 ns/op +func BenchmarkFlatten(b *testing.B) { + fill := func(parent snapshot) *diffLayer { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + for i := 0; i < 100; i++ { + accountKey := randomHash() + accounts[accountKey] = randomAccount() + + accStorage := make(map[common.Hash][]byte) + for i := 0; i < 20; i++ { + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + + } + storage[accountKey] = accStorage + } + return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + var layer snapshot + layer = emptyLayer() + for i := 1; i < 128; i++ { + layer = fill(layer) + } + b.StartTimer() + + for i := 1; i < 128; i++ { + dl, ok := layer.(*diffLayer) + if !ok { + break + } + layer = dl.flatten() + } + b.StopTimer() + } +} + +// This test writes ~324M of diff layers to disk, spread over +// - 128 individual layers, +// - each with 200 accounts +// - containing 200 slots +// +// BenchmarkJournal-6 1 1471373923 ns/ops +// BenchmarkJournal-6 1 1208083335 ns/op // bufio writer +func BenchmarkJournal(b *testing.B) { + fill := func(parent snapshot) *diffLayer { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + for i := 0; i < 200; i++ { + accountKey := randomHash() + accounts[accountKey] = randomAccount() + + accStorage := make(map[common.Hash][]byte) + for i := 0; i < 200; i++ { + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + + } + storage[accountKey] = accStorage + } + return newDiffLayer(parent, common.Hash{}, destructs, accounts, storage) + } + layer := snapshot(new(diskLayer)) + for i := 1; i < 128; i++ { + layer = fill(layer) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + layer.Journal(new(bytes.Buffer)) + } +} diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go new file mode 100644 index 0000000000..febb3e6753 --- /dev/null +++ b/core/state/snapshot/disklayer.go @@ -0,0 +1,168 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "sync" + + "github.com/VictoriaMetrics/fastcache" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" +) + +// diskLayer is a low level persistent snapshot built on top of a key-value store. +type diskLayer struct { + diskdb ethdb.KeyValueStore // Key-value store containing the base snapshot + triedb *trie.Database // Trie node cache for reconstuction purposes + cache *fastcache.Cache // Cache to avoid hitting the disk for direct access + + root common.Hash // Root hash of the base snapshot + stale bool // Signals that the layer became stale (state progressed) + + genMarker []byte // Marker for the state that's indexed during initial layer generation + genPending chan struct{} // Notification channel when generation is done (test synchronicity) + genAbort chan chan *generatorStats // Notification channel to abort generating the snapshot in this layer + + lock sync.RWMutex +} + +// Root returns root hash for which this snapshot was made. +func (dl *diskLayer) Root() common.Hash { + return dl.root +} + +// Parent always returns nil as there's no layer below the disk. +func (dl *diskLayer) Parent() snapshot { + return nil +} + +// Stale return whether this layer has become stale (was flattened across) or if +// it's still live. +func (dl *diskLayer) Stale() bool { + dl.lock.RLock() + defer dl.lock.RUnlock() + + return dl.stale +} + +// Account directly retrieves the account associated with a particular hash in +// the snapshot slim data format. +func (dl *diskLayer) Account(hash common.Hash) (*types.SlimAccount, error) { + data, err := dl.AccountRLP(hash) + if err != nil { + return nil, err + } + if len(data) == 0 { // can be both nil and []byte{} + return nil, nil + } + account := new(types.SlimAccount) + if err := rlp.DecodeBytes(data, account); err != nil { + panic(err) + } + return account, nil +} + +// AccountRLP directly retrieves the account RLP associated with a particular +// hash in the snapshot slim data format. +func (dl *diskLayer) AccountRLP(hash common.Hash) ([]byte, error) { + dl.lock.RLock() + defer dl.lock.RUnlock() + + // If the layer was flattened into, consider it invalid (any live reference to + // the original should be marked as unusable). + if dl.stale { + return nil, ErrSnapshotStale + } + // If the layer is being generated, ensure the requested hash has already been + // covered by the generator. + if dl.genMarker != nil && bytes.Compare(hash[:], dl.genMarker) > 0 { + return nil, ErrNotCoveredYet + } + // If we're in the disk layer, all diff layers missed + snapshotDirtyAccountMissMeter.Mark(1) + + // Try to retrieve the account from the memory cache + if blob, found := dl.cache.HasGet(nil, hash[:]); found { + snapshotCleanAccountHitMeter.Mark(1) + snapshotCleanAccountReadMeter.Mark(int64(len(blob))) + return blob, nil + } + // Cache doesn't contain account, pull from disk and cache for later + blob := rawdb.ReadAccountSnapshot(dl.diskdb, hash) + dl.cache.Set(hash[:], blob) + + snapshotCleanAccountMissMeter.Mark(1) + if n := len(blob); n > 0 { + snapshotCleanAccountWriteMeter.Mark(int64(n)) + } else { + snapshotCleanAccountInexMeter.Mark(1) + } + return blob, nil +} + +// Storage directly retrieves the storage data associated with a particular hash, +// within a particular account. +func (dl *diskLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) { + dl.lock.RLock() + defer dl.lock.RUnlock() + + // If the layer was flattened into, consider it invalid (any live reference to + // the original should be marked as unusable). + if dl.stale { + return nil, ErrSnapshotStale + } + key := append(accountHash[:], storageHash[:]...) + + // If the layer is being generated, ensure the requested hash has already been + // covered by the generator. + if dl.genMarker != nil && bytes.Compare(key, dl.genMarker) > 0 { + return nil, ErrNotCoveredYet + } + // If we're in the disk layer, all diff layers missed + snapshotDirtyStorageMissMeter.Mark(1) + + // Try to retrieve the storage slot from the memory cache + if blob, found := dl.cache.HasGet(nil, key); found { + snapshotCleanStorageHitMeter.Mark(1) + snapshotCleanStorageReadMeter.Mark(int64(len(blob))) + return blob, nil + } + // Cache doesn't contain storage slot, pull from disk and cache for later + blob := rawdb.ReadStorageSnapshot(dl.diskdb, accountHash, storageHash) + dl.cache.Set(key, blob) + + snapshotCleanStorageMissMeter.Mark(1) + if n := len(blob); n > 0 { + snapshotCleanStorageWriteMeter.Mark(int64(n)) + } else { + snapshotCleanStorageInexMeter.Mark(1) + } + return blob, nil +} + +// Update creates a new layer on top of the existing snapshot diff tree with +// the specified data items. Note, the maps are retained by the method to avoid +// copying everything. +func (dl *diskLayer) Update(blockHash common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer { + return newDiffLayer(dl, blockHash, destructs, accounts, storage) +} diff --git a/core/state/snapshot/disklayer_test.go b/core/state/snapshot/disklayer_test.go new file mode 100644 index 0000000000..652e531b25 --- /dev/null +++ b/core/state/snapshot/disklayer_test.go @@ -0,0 +1,435 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "testing" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb/memorydb" +) + +// reverse reverses the contents of a byte slice. It's used to update random accs +// with deterministic changes. +func reverse(blob []byte) []byte { + res := make([]byte, len(blob)) + for i, b := range blob { + res[len(blob)-1-i] = b + } + return res +} + +// Tests that merging something into a disk layer persists it into the database +// and invalidates any previously written and cached values. +func TestDiskMerge(t *testing.T) { + // Create some accounts in the disk layer + db := memorydb.New() + + var ( + accNoModNoCache = common.Hash{0x1} + accNoModCache = common.Hash{0x2} + accModNoCache = common.Hash{0x3} + accModCache = common.Hash{0x4} + accDelNoCache = common.Hash{0x5} + accDelCache = common.Hash{0x6} + conNoModNoCache = common.Hash{0x7} + conNoModNoCacheSlot = common.Hash{0x70} + conNoModCache = common.Hash{0x8} + conNoModCacheSlot = common.Hash{0x80} + conModNoCache = common.Hash{0x9} + conModNoCacheSlot = common.Hash{0x90} + conModCache = common.Hash{0xa} + conModCacheSlot = common.Hash{0xa0} + conDelNoCache = common.Hash{0xb} + conDelNoCacheSlot = common.Hash{0xb0} + conDelCache = common.Hash{0xc} + conDelCacheSlot = common.Hash{0xc0} + conNukeNoCache = common.Hash{0xd} + conNukeNoCacheSlot = common.Hash{0xd0} + conNukeCache = common.Hash{0xe} + conNukeCacheSlot = common.Hash{0xe0} + baseRoot = randomHash() + diffRoot = randomHash() + ) + + rawdb.WriteAccountSnapshot(db, accNoModNoCache, accNoModNoCache[:]) + rawdb.WriteAccountSnapshot(db, accNoModCache, accNoModCache[:]) + rawdb.WriteAccountSnapshot(db, accModNoCache, accModNoCache[:]) + rawdb.WriteAccountSnapshot(db, accModCache, accModCache[:]) + rawdb.WriteAccountSnapshot(db, accDelNoCache, accDelNoCache[:]) + rawdb.WriteAccountSnapshot(db, accDelCache, accDelCache[:]) + + rawdb.WriteAccountSnapshot(db, conNoModNoCache, conNoModNoCache[:]) + rawdb.WriteStorageSnapshot(db, conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conNoModCache, conNoModCache[:]) + rawdb.WriteStorageSnapshot(db, conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conModNoCache, conModNoCache[:]) + rawdb.WriteStorageSnapshot(db, conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conModCache, conModCache[:]) + rawdb.WriteStorageSnapshot(db, conModCache, conModCacheSlot, conModCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conDelNoCache, conDelNoCache[:]) + rawdb.WriteStorageSnapshot(db, conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conDelCache, conDelCache[:]) + rawdb.WriteStorageSnapshot(db, conDelCache, conDelCacheSlot, conDelCacheSlot[:]) + + rawdb.WriteAccountSnapshot(db, conNukeNoCache, conNukeNoCache[:]) + rawdb.WriteStorageSnapshot(db, conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:]) + rawdb.WriteAccountSnapshot(db, conNukeCache, conNukeCache[:]) + rawdb.WriteStorageSnapshot(db, conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:]) + + rawdb.WriteSnapshotRoot(db, baseRoot) + + // Create a disk layer based on the above and cache in some data + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + baseRoot: &diskLayer{ + diskdb: db, + cache: fastcache.New(500 * 1024), + root: baseRoot, + }, + }, + } + base := snaps.Snapshot(baseRoot) + base.AccountRLP(accNoModCache) + base.AccountRLP(accModCache) + base.AccountRLP(accDelCache) + base.Storage(conNoModCache, conNoModCacheSlot) + base.Storage(conModCache, conModCacheSlot) + base.Storage(conDelCache, conDelCacheSlot) + base.Storage(conNukeCache, conNukeCacheSlot) + + // Modify or delete some accounts, flatten everything onto disk + if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{ + accDelNoCache: struct{}{}, + accDelCache: struct{}{}, + conNukeNoCache: struct{}{}, + conNukeCache: struct{}{}, + }, map[common.Hash][]byte{ + accModNoCache: reverse(accModNoCache[:]), + accModCache: reverse(accModCache[:]), + }, map[common.Hash]map[common.Hash][]byte{ + conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])}, + conModCache: {conModCacheSlot: reverse(conModCacheSlot[:])}, + conDelNoCache: {conDelNoCacheSlot: nil}, + conDelCache: {conDelCacheSlot: nil}, + }); err != nil { + t.Fatalf("failed to update snapshot tree: %v", err) + } + if err := snaps.Cap(diffRoot, 0); err != nil { + t.Fatalf("failed to flatten snapshot tree: %v", err) + } + // Retrieve all the data through the disk layer and validate it + base = snaps.Snapshot(diffRoot) + if _, ok := base.(*diskLayer); !ok { + t.Fatalf("update not flattend into the disk layer") + } + + // assertAccount ensures that an account matches the given blob. + assertAccount := func(account common.Hash, data []byte) { + t.Helper() + blob, err := base.AccountRLP(account) + if err != nil { + t.Errorf("account access (%x) failed: %v", account, err) + } else if !bytes.Equal(blob, data) { + t.Errorf("account access (%x) mismatch: have %x, want %x", account, blob, data) + } + } + assertAccount(accNoModNoCache, accNoModNoCache[:]) + assertAccount(accNoModCache, accNoModCache[:]) + assertAccount(accModNoCache, reverse(accModNoCache[:])) + assertAccount(accModCache, reverse(accModCache[:])) + assertAccount(accDelNoCache, nil) + assertAccount(accDelCache, nil) + + // assertStorage ensures that a storage slot matches the given blob. + assertStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + blob, err := base.Storage(account, slot) + if err != nil { + t.Errorf("storage access (%x:%x) failed: %v", account, slot, err) + } else if !bytes.Equal(blob, data) { + t.Errorf("storage access (%x:%x) mismatch: have %x, want %x", account, slot, blob, data) + } + } + assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertStorage(conDelCache, conDelCacheSlot, nil) + assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertStorage(conNukeCache, conNukeCacheSlot, nil) + + // Retrieve all the data directly from the database and validate it + + // assertDatabaseAccount ensures that an account from the database matches the given blob. + assertDatabaseAccount := func(account common.Hash, data []byte) { + t.Helper() + if blob := rawdb.ReadAccountSnapshot(db, account); !bytes.Equal(blob, data) { + t.Errorf("account database access (%x) mismatch: have %x, want %x", account, blob, data) + } + } + assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:]) + assertDatabaseAccount(accNoModCache, accNoModCache[:]) + assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:])) + assertDatabaseAccount(accModCache, reverse(accModCache[:])) + assertDatabaseAccount(accDelNoCache, nil) + assertDatabaseAccount(accDelCache, nil) + + // assertDatabaseStorage ensures that a storage slot from the database matches the given blob. + assertDatabaseStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + if blob := rawdb.ReadStorageSnapshot(db, account, slot); !bytes.Equal(blob, data) { + t.Errorf("storage database access (%x:%x) mismatch: have %x, want %x", account, slot, blob, data) + } + } + assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertDatabaseStorage(conDelCache, conDelCacheSlot, nil) + assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil) +} + +// Tests that merging something into a disk layer persists it into the database +// and invalidates any previously written and cached values, discarding anything +// after the in-progress generation marker. +func TestDiskPartialMerge(t *testing.T) { + // Iterate the test a few times to ensure we pick various internal orderings + // for the data slots as well as the progress marker. + for i := 0; i < 1024; i++ { + // Create some accounts in the disk layer + db := memorydb.New() + + var ( + accNoModNoCache = randomHash() + accNoModCache = randomHash() + accModNoCache = randomHash() + accModCache = randomHash() + accDelNoCache = randomHash() + accDelCache = randomHash() + conNoModNoCache = randomHash() + conNoModNoCacheSlot = randomHash() + conNoModCache = randomHash() + conNoModCacheSlot = randomHash() + conModNoCache = randomHash() + conModNoCacheSlot = randomHash() + conModCache = randomHash() + conModCacheSlot = randomHash() + conDelNoCache = randomHash() + conDelNoCacheSlot = randomHash() + conDelCache = randomHash() + conDelCacheSlot = randomHash() + conNukeNoCache = randomHash() + conNukeNoCacheSlot = randomHash() + conNukeCache = randomHash() + conNukeCacheSlot = randomHash() + baseRoot = randomHash() + diffRoot = randomHash() + genMarker = append(randomHash().Bytes(), randomHash().Bytes()...) + ) + + // insertAccount injects an account into the database if it's after the + // generator marker, drops the op otherwise. This is needed to seed the + // database with a valid starting snapshot. + insertAccount := func(account common.Hash, data []byte) { + if bytes.Compare(account[:], genMarker) <= 0 { + rawdb.WriteAccountSnapshot(db, account, data[:]) + } + } + insertAccount(accNoModNoCache, accNoModNoCache[:]) + insertAccount(accNoModCache, accNoModCache[:]) + insertAccount(accModNoCache, accModNoCache[:]) + insertAccount(accModCache, accModCache[:]) + insertAccount(accDelNoCache, accDelNoCache[:]) + insertAccount(accDelCache, accDelCache[:]) + + // insertStorage injects a storage slot into the database if it's after + // the generator marker, drops the op otherwise. This is needed to seed + // the database with a valid starting snapshot. + insertStorage := func(account common.Hash, slot common.Hash, data []byte) { + if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 { + rawdb.WriteStorageSnapshot(db, account, slot, data[:]) + } + } + insertAccount(conNoModNoCache, conNoModNoCache[:]) + insertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + insertAccount(conNoModCache, conNoModCache[:]) + insertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + insertAccount(conModNoCache, conModNoCache[:]) + insertStorage(conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:]) + insertAccount(conModCache, conModCache[:]) + insertStorage(conModCache, conModCacheSlot, conModCacheSlot[:]) + insertAccount(conDelNoCache, conDelNoCache[:]) + insertStorage(conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:]) + insertAccount(conDelCache, conDelCache[:]) + insertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:]) + + insertAccount(conNukeNoCache, conNukeNoCache[:]) + insertStorage(conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:]) + insertAccount(conNukeCache, conNukeCache[:]) + insertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:]) + + rawdb.WriteSnapshotRoot(db, baseRoot) + + // Create a disk layer based on the above using a random progress marker + // and cache in some data. + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + baseRoot: &diskLayer{ + diskdb: db, + cache: fastcache.New(500 * 1024), + root: baseRoot, + }, + }, + } + snaps.layers[baseRoot].(*diskLayer).genMarker = genMarker + base := snaps.Snapshot(baseRoot) + + // assertAccount ensures that an account matches the given blob if it's + // already covered by the disk snapshot, and errors out otherwise. + assertAccount := func(account common.Hash, data []byte) { + t.Helper() + blob, err := base.AccountRLP(account) + if bytes.Compare(account[:], genMarker) > 0 && err != ErrNotCoveredYet { + t.Fatalf("test %d: post-marker (%x) account access (%x) succeeded: %x", i, genMarker, account, blob) + } + if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) { + t.Fatalf("test %d: pre-marker (%x) account access (%x) mismatch: have %x, want %x", i, genMarker, account, blob, data) + } + } + assertAccount(accNoModCache, accNoModCache[:]) + assertAccount(accModCache, accModCache[:]) + assertAccount(accDelCache, accDelCache[:]) + + // assertStorage ensures that a storage slot matches the given blob if + // it's already covered by the disk snapshot, and errors out otherwise. + assertStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + blob, err := base.Storage(account, slot) + if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && err != ErrNotCoveredYet { + t.Fatalf("test %d: post-marker (%x) storage access (%x:%x) succeeded: %x", i, genMarker, account, slot, blob) + } + if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) { + t.Fatalf("test %d: pre-marker (%x) storage access (%x:%x) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data) + } + } + assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertStorage(conModCache, conModCacheSlot, conModCacheSlot[:]) + assertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:]) + assertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:]) + + // Modify or delete some accounts, flatten everything onto disk + if err := snaps.Update(diffRoot, baseRoot, map[common.Hash]struct{}{ + accDelNoCache: struct{}{}, + accDelCache: struct{}{}, + conNukeNoCache: struct{}{}, + conNukeCache: struct{}{}, + }, map[common.Hash][]byte{ + accModNoCache: reverse(accModNoCache[:]), + accModCache: reverse(accModCache[:]), + }, map[common.Hash]map[common.Hash][]byte{ + conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])}, + conModCache: {conModCacheSlot: reverse(conModCacheSlot[:])}, + conDelNoCache: {conDelNoCacheSlot: nil}, + conDelCache: {conDelCacheSlot: nil}, + }); err != nil { + t.Fatalf("test %d: failed to update snapshot tree: %v", i, err) + } + if err := snaps.Cap(diffRoot, 0); err != nil { + t.Fatalf("test %d: failed to flatten snapshot tree: %v", i, err) + } + // Retrieve all the data through the disk layer and validate it + base = snaps.Snapshot(diffRoot) + if _, ok := base.(*diskLayer); !ok { + t.Fatalf("test %d: update not flattend into the disk layer", i) + } + assertAccount(accNoModNoCache, accNoModNoCache[:]) + assertAccount(accNoModCache, accNoModCache[:]) + assertAccount(accModNoCache, reverse(accModNoCache[:])) + assertAccount(accModCache, reverse(accModCache[:])) + assertAccount(accDelNoCache, nil) + assertAccount(accDelCache, nil) + + assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertStorage(conDelCache, conDelCacheSlot, nil) + assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertStorage(conNukeCache, conNukeCacheSlot, nil) + + // Retrieve all the data directly from the database and validate it + + // assertDatabaseAccount ensures that an account inside the database matches + // the given blob if it's already covered by the disk snapshot, and does not + // exist otherwise. + assertDatabaseAccount := func(account common.Hash, data []byte) { + t.Helper() + blob := rawdb.ReadAccountSnapshot(db, account) + if bytes.Compare(account[:], genMarker) > 0 && blob != nil { + t.Fatalf("test %d: post-marker (%x) account database access (%x) succeeded: %x", i, genMarker, account, blob) + } + if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) { + t.Fatalf("test %d: pre-marker (%x) account database access (%x) mismatch: have %x, want %x", i, genMarker, account, blob, data) + } + } + assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:]) + assertDatabaseAccount(accNoModCache, accNoModCache[:]) + assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:])) + assertDatabaseAccount(accModCache, reverse(accModCache[:])) + assertDatabaseAccount(accDelNoCache, nil) + assertDatabaseAccount(accDelCache, nil) + + // assertDatabaseStorage ensures that a storage slot inside the database + // matches the given blob if it's already covered by the disk snapshot, + // and does not exist otherwise. + assertDatabaseStorage := func(account common.Hash, slot common.Hash, data []byte) { + t.Helper() + blob := rawdb.ReadStorageSnapshot(db, account, slot) + if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && blob != nil { + t.Fatalf("test %d: post-marker (%x) storage database access (%x:%x) succeeded: %x", i, genMarker, account, slot, blob) + } + if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) { + t.Fatalf("test %d: pre-marker (%x) storage database access (%x:%x) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data) + } + } + assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:]) + assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:]) + assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:])) + assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:])) + assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil) + assertDatabaseStorage(conDelCache, conDelCacheSlot, nil) + assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil) + assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil) + } +} + +// Tests that merging something into a disk layer persists it into the database +// and invalidates any previously written and cached values, discarding anything +// after the in-progress generation marker. +// +// This test case is a tiny specialized case of TestDiskPartialMerge, which tests +// some very specific cornercases that random tests won't ever trigger. +func TestDiskMidAccountPartialMerge(t *testing.T) { +} diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go new file mode 100644 index 0000000000..ea5b59a72f --- /dev/null +++ b/core/state/snapshot/generate.go @@ -0,0 +1,286 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "encoding/binary" + "fmt" + "math/big" + "time" + + "github.com/VictoriaMetrics/fastcache" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" +) + +var ( + // emptyRoot is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // emptyCode is the known hash of the empty EVM bytecode. + emptyCode = crypto.Keccak256Hash(nil) +) + +// generatorStats is a collection of statistics gathered by the snapshot generator +// for logging purposes. +type generatorStats struct { + origin uint64 // Origin prefix where generation started + start time.Time // Timestamp when generation started + accounts uint64 // Number of accounts indexed + slots uint64 // Number of storage slots indexed + storage common.StorageSize // Account and storage slot size +} + +// Log creates an contextual log with the given message and the context pulled +// from the internally maintained statistics. +func (gs *generatorStats) Log(msg string, marker []byte) { + var ctx []interface{} + + // Figure out whether we're after or within an account + switch len(marker) { + case common.HashLength: + ctx = append(ctx, []interface{}{"at", common.BytesToHash(marker)}...) + case 2 * common.HashLength: + ctx = append(ctx, []interface{}{ + "in", common.BytesToHash(marker[:common.HashLength]), + "at", common.BytesToHash(marker[common.HashLength:]), + }...) + } + // Add the usual measurements + ctx = append(ctx, []interface{}{ + "accounts", gs.accounts, + "slots", gs.slots, + "storage", gs.storage, + "elapsed", common.PrettyDuration(time.Since(gs.start)), + }...) + // Calculate the estimated indexing time based on current stats + if len(marker) > 0 { + if done := binary.BigEndian.Uint64(marker[:8]) - gs.origin; done > 0 { + left := math.MaxUint64 - binary.BigEndian.Uint64(marker[:8]) + + speed := done/uint64(time.Since(gs.start)/time.Millisecond+1) + 1 // +1s to avoid division by zero + ctx = append(ctx, []interface{}{ + "eta", common.PrettyDuration(time.Duration(left/speed) * time.Millisecond), + }...) + } + } + log.Info(msg, ctx...) +} + +// generateSnapshot regenerates a brand new snapshot based on an existing state +// database and head block asynchronously. The snapshot is returned immediately +// and generation is continued in the background until done. +func generateSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) *diskLayer { + // Create a new disk layer with an initialized state marker at zero + var ( + stats = &generatorStats{start: time.Now()} + batch = diskdb.NewBatch() + genMarker = []byte{} // Initialized but empty! + ) + // Create a new disk layer with an initialized state marker at zero + rawdb.WriteSnapshotRoot(diskdb, root) + if err := batch.Write(); err != nil { + log.Crit("Failed to write initialized state marker", "err", err) + } + base := &diskLayer{ + diskdb: diskdb, + triedb: triedb, + root: root, + cache: fastcache.New(cache * 1024 * 1024), + genMarker: genMarker, // Initialized but empty! + genPending: make(chan struct{}), + genAbort: make(chan chan *generatorStats), + } + go base.generate(stats) + log.Debug("Start snapshot generation", "root", root) + return base +} + +// journalProgress persists the generator stats into the database to resume later. +func journalProgress(db ethdb.KeyValueWriter, marker []byte, stats *generatorStats) { + // Write out the generator marker. Note it's a standalone disk layer generator + // which is not mixed with journal. It's ok if the generator is persisted while + // journal is not. + entry := journalGenerator{ + Done: marker == nil, + Marker: marker, + } + if stats != nil { + entry.Accounts = stats.accounts + entry.Slots = stats.slots + entry.Storage = uint64(stats.storage) + } + blob, err := rlp.EncodeToBytes(entry) + if err != nil { + panic(err) // Cannot happen, here to catch dev errors + } + var logstr string + switch { + case marker == nil: + logstr = "done" + case bytes.Equal(marker, []byte{}): + logstr = "empty" + case len(marker) == common.HashLength: + logstr = fmt.Sprintf("%#x", marker) + default: + logstr = fmt.Sprintf("%#x:%#x", marker[:common.HashLength], marker[common.HashLength:]) + } + log.Debug("Journalled generator progress", "progress", logstr) + rawdb.WriteSnapshotGenerator(db, blob) +} + +// generate is a background thread that iterates over the state and storage tries, +// constructing the state snapshot. All the arguments are purely for statistics +// gethering and logging, since the method surfs the blocks as they arrive, often +// being restarted. +func (dl *diskLayer) generate(stats *generatorStats) { + // Create an account and state iterator pointing to the current generator marker + accTrie, err := trie.NewSecure(dl.root, dl.triedb) + if err != nil { + // The account trie is missing (GC), surf the chain until one becomes available + stats.Log("Trie missing, state snapshotting paused", dl.genMarker) + + abort := <-dl.genAbort + abort <- stats + return + } + stats.Log("Resuming state snapshot generation", dl.genMarker) + + var accMarker []byte + if len(dl.genMarker) > 0 { // []byte{} is the start, use nil for that + accMarker = dl.genMarker[:common.HashLength] + } + accIt := trie.NewIterator(accTrie.NodeIterator(accMarker)) + batch := dl.diskdb.NewBatch() + + // Iterate from the previous marker and continue generating the state snapshot + logged := time.Now() + for accIt.Next() { + // Retrieve the current account and flatten it into the internal format + accountHash := common.BytesToHash(accIt.Key) + + var acc struct { + Nonce uint64 + Balance *big.Int + Root common.Hash + CodeHash []byte + } + if err := rlp.DecodeBytes(accIt.Value, &acc); err != nil { + log.Crit("Invalid account encountered during snapshot creation", "err", err) + } + data := types.SlimAccountRLP(acc) + + // If the account is not yet in-progress, write it out + if accMarker == nil || !bytes.Equal(accountHash[:], accMarker) { + rawdb.WriteAccountSnapshot(batch, accountHash, data) + stats.storage += common.StorageSize(1 + common.HashLength + len(data)) + stats.accounts++ + } + // If we've exceeded our batch allowance or termination was requested, flush to disk + var abort chan *generatorStats + select { + case abort = <-dl.genAbort: + default: + } + if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil { + // Only write and set the marker if we actually did something useful + if batch.ValueSize() > 0 { + batch.Write() + batch.Reset() + + dl.lock.Lock() + dl.genMarker = accountHash[:] + dl.lock.Unlock() + } + if abort != nil { + stats.Log("Aborting state snapshot generation", accountHash[:]) + abort <- stats + return + } + } + // If the account is in-progress, continue where we left off (otherwise iterate all) + if acc.Root != emptyRoot { + storeTrie, err := trie.NewSecure(acc.Root, dl.triedb) + if err != nil { + log.Crit("Storage trie inaccessible for snapshot generation", "err", err) + } + var storeMarker []byte + if accMarker != nil && bytes.Equal(accountHash[:], accMarker) && len(dl.genMarker) > common.HashLength { + storeMarker = dl.genMarker[common.HashLength:] + } + storeIt := trie.NewIterator(storeTrie.NodeIterator(storeMarker)) + for storeIt.Next() { + rawdb.WriteStorageSnapshot(batch, accountHash, common.BytesToHash(storeIt.Key), storeIt.Value) + stats.storage += common.StorageSize(1 + 2*common.HashLength + len(storeIt.Value)) + stats.slots++ + + // If we've exceeded our batch allowance or termination was requested, flush to disk + var abort chan *generatorStats + select { + case abort = <-dl.genAbort: + default: + } + if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil { + // Only write and set the marker if we actually did something useful + if batch.ValueSize() > 0 { + batch.Write() + batch.Reset() + + dl.lock.Lock() + dl.genMarker = append(accountHash[:], storeIt.Key...) + dl.lock.Unlock() + } + if abort != nil { + stats.Log("Aborting state snapshot generation", append(accountHash[:], storeIt.Key...)) + abort <- stats + return + } + } + } + } + if time.Since(logged) > 8*time.Second { + stats.Log("Generating state snapshot", accIt.Key) + logged = time.Now() + } + // Some account processed, unmark the marker + accMarker = nil + } + // Snapshot fully generated, set the marker to nil + if batch.ValueSize() > 0 { + batch.Write() + } + log.Info("Generated state snapshot", "accounts", stats.accounts, "slots", stats.slots, + "storage", stats.storage, "elapsed", common.PrettyDuration(time.Since(stats.start))) + + dl.lock.Lock() + dl.genMarker = nil + close(dl.genPending) + dl.lock.Unlock() + + // Someone will be looking for us, wait it out + abort := <-dl.genAbort + abort <- nil +} diff --git a/core/state/snapshot/iterator.go b/core/state/snapshot/iterator.go new file mode 100644 index 0000000000..b62fb30e34 --- /dev/null +++ b/core/state/snapshot/iterator.go @@ -0,0 +1,221 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "fmt" + "sort" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" +) + +// Iterator is an iterator to step over all the accounts or the specific +// storage in a snapshot which may or may not be composed of multiple layers. +type Iterator interface { + // Next steps the iterator forward one element, returning false if exhausted, + // or an error if iteration failed for some reason (e.g. root being iterated + // becomes stale and garbage collected). + Next() bool + + // Error returns any failure that occurred during iteration, which might have + // caused a premature iteration exit (e.g. snapshot stack becoming stale). + Error() error + + // Hash returns the hash of the account or storage slot the iterator is + // currently at. + Hash() common.Hash + + // Release releases associated resources. Release should always succeed and + // can be called multiple times without causing error. + Release() +} + +// AccountIterator is an iterator to step over all the accounts in a snapshot, +// which may or may not be composed of multiple layers. +type AccountIterator interface { + Iterator + + // Account returns the RLP encoded slim account the iterator is currently at. + // An error will be returned if the iterator becomes invalid + Account() []byte +} + +// diffAccountIterator is an account iterator that steps over the accounts (both +// live and deleted) contained within a single diff layer. Higher order iterators +// will use the deleted accounts to skip deeper iterators. +type diffAccountIterator struct { + // curHash is the current hash the iterator is positioned on. The field is + // explicitly tracked since the referenced diff layer might go stale after + // the iterator was positioned and we don't want to fail accessing the old + // hash as long as the iterator is not touched any more. + curHash common.Hash + + layer *diffLayer // Live layer to retrieve values from + keys []common.Hash // Keys left in the layer to iterate + fail error // Any failures encountered (stale) +} + +// StorageIterator is an iterator to step over the specific storage in a snapshot, +// which may or may not be composed of multiple layers. +type StorageIterator interface { + Iterator + + // Slot returns the storage slot the iterator is currently at. An error will + // be returned if the iterator becomes invalid + Slot() []byte +} + +// AccountIterator creates an account iterator over a single diff layer. +func (dl *diffLayer) AccountIterator(seek common.Hash) AccountIterator { + // Seek out the requested starting account + hashes := dl.AccountList() + index := sort.Search(len(hashes), func(i int) bool { + return bytes.Compare(seek[:], hashes[i][:]) < 0 + }) + // Assemble and returned the already seeked iterator + return &diffAccountIterator{ + layer: dl, + keys: hashes[index:], + } +} + +// Next steps the iterator forward one element, returning false if exhausted. +func (it *diffAccountIterator) Next() bool { + // If the iterator was already stale, consider it a programmer error. Although + // we could just return false here, triggering this path would probably mean + // somebody forgot to check for Error, so lets blow up instead of undefined + // behavior that's hard to debug. + if it.fail != nil { + panic(fmt.Sprintf("called Next of failed iterator: %v", it.fail)) + } + // Stop iterating if all keys were exhausted + if len(it.keys) == 0 { + return false + } + if it.layer.Stale() { + it.fail, it.keys = ErrSnapshotStale, nil + return false + } + // Iterator seems to be still alive, retrieve and cache the live hash + it.curHash = it.keys[0] + // key cached, shift the iterator and notify the user of success + it.keys = it.keys[1:] + return true +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit (e.g. snapshot stack becoming stale). +func (it *diffAccountIterator) Error() error { + return it.fail +} + +// Hash returns the hash of the account the iterator is currently at. +func (it *diffAccountIterator) Hash() common.Hash { + return it.curHash +} + +// Account returns the RLP encoded slim account the iterator is currently at. +// This method may _fail_, if the underlying layer has been flattened between +// the call to Next and Acccount. That type of error will set it.Err. +// This method assumes that flattening does not delete elements from +// the accountdata mapping (writing nil into it is fine though), and will panic +// if elements have been deleted. +func (it *diffAccountIterator) Account() []byte { + it.layer.lock.RLock() + blob, ok := it.layer.accountData[it.curHash] + if !ok { + if _, ok := it.layer.destructSet[it.curHash]; ok { + return nil + } + panic(fmt.Sprintf("iterator referenced non-existent account: %x", it.curHash)) + } + it.layer.lock.RUnlock() + if it.layer.Stale() { + it.fail, it.keys = ErrSnapshotStale, nil + } + return blob +} + +// Release is a noop for diff account iterators as there are no held resources. +func (it *diffAccountIterator) Release() {} + +// diskAccountIterator is an account iterator that steps over the live accounts +// contained within a disk layer. +type diskAccountIterator struct { + layer *diskLayer + it ethdb.Iterator +} + +// AccountIterator creates an account iterator over a disk layer. +func (dl *diskLayer) AccountIterator(seek common.Hash) AccountIterator { + pos := common.TrimRightZeroes(seek[:]) + return &diskAccountIterator{ + layer: dl, + it: dl.diskdb.NewIterator(rawdb.SnapshotAccountPrefix, pos), + } +} + +// Next steps the iterator forward one element, returning false if exhausted. +func (it *diskAccountIterator) Next() bool { + // If the iterator was already exhausted, don't bother + if it.it == nil { + return false + } + // Try to advance the iterator and release it if we reached the end + for { + if !it.it.Next() || !bytes.HasPrefix(it.it.Key(), rawdb.SnapshotAccountPrefix) { + it.it.Release() + it.it = nil + return false + } + if len(it.it.Key()) == len(rawdb.SnapshotAccountPrefix)+common.HashLength { + break + } + } + return true +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit (e.g. snapshot stack becoming stale). +// +// A diff layer is immutable after creation content wise and can always be fully +// iterated without error, so this method always returns nil. +func (it *diskAccountIterator) Error() error { + return it.it.Error() +} + +// Hash returns the hash of the account the iterator is currently at. +func (it *diskAccountIterator) Hash() common.Hash { + return common.BytesToHash(it.it.Key()) +} + +// Account returns the RLP encoded slim account the iterator is currently at. +func (it *diskAccountIterator) Account() []byte { + return it.it.Value() +} + +// Release releases the database snapshot held during iteration. +func (it *diskAccountIterator) Release() { + // The iterator is auto-released on exhaustion, so make sure it's still alive + if it.it != nil { + it.it.Release() + it.it = nil + } +} diff --git a/core/state/snapshot/iterator_binary.go b/core/state/snapshot/iterator_binary.go new file mode 100644 index 0000000000..d8df968ea5 --- /dev/null +++ b/core/state/snapshot/iterator_binary.go @@ -0,0 +1,115 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + + "github.com/tomochain/tomochain/common" +) + +// binaryAccountIterator is a simplistic iterator to step over the accounts in +// a snapshot, which may or may npt be composed of multiple layers. Performance +// wise this iterator is slow, it's meant for cross validating the fast one, +type binaryAccountIterator struct { + a *diffAccountIterator + b AccountIterator + aDone bool + bDone bool + k common.Hash + fail error +} + +// newBinaryAccountIterator creates a simplistic account iterator to step over +// all the accounts in a slow, but eaily verifiable way. +func (dl *diffLayer) newBinaryAccountIterator() AccountIterator { + parent, ok := dl.parent.(*diffLayer) + if !ok { + // parent is the disk layer + return dl.AccountIterator(common.Hash{}) + } + l := &binaryAccountIterator{ + a: dl.AccountIterator(common.Hash{}).(*diffAccountIterator), + b: parent.newBinaryAccountIterator(), + } + l.aDone = !l.a.Next() + l.bDone = !l.b.Next() + return l +} + +// Next steps the iterator forward one element, returning false if exhausted, +// or an error if iteration failed for some reason (e.g. root being iterated +// becomes stale and garbage collected). +func (it *binaryAccountIterator) Next() bool { + if it.aDone && it.bDone { + return false + } + nextB := it.b.Hash() +first: + nextA := it.a.Hash() + if it.aDone { + it.bDone = !it.b.Next() + it.k = nextB + return true + } + if it.bDone { + it.aDone = !it.a.Next() + it.k = nextA + return true + } + if diff := bytes.Compare(nextA[:], nextB[:]); diff < 0 { + it.aDone = !it.a.Next() + it.k = nextA + return true + } else if diff == 0 { + // Now we need to advance one of them + it.aDone = !it.a.Next() + goto first + } + it.bDone = !it.b.Next() + it.k = nextB + return true +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit (e.g. snapshot stack becoming stale). +func (it *binaryAccountIterator) Error() error { + return it.fail +} + +// Hash returns the hash of the account the iterator is currently at. +func (it *binaryAccountIterator) Hash() common.Hash { + return it.k +} + +// Account returns the RLP encoded slim account the iterator is currently at, or +// nil if the iterated snapshot stack became stale (you can check Error after +// to see if it failed or not). +func (it *binaryAccountIterator) Account() []byte { + blob, err := it.a.layer.AccountRLP(it.k) + if err != nil { + it.fail = err + return nil + } + return blob +} + +// Release recursively releases all the iterators in the stack. +func (it *binaryAccountIterator) Release() { + it.a.Release() + it.b.Release() +} diff --git a/core/state/snapshot/iterator_fast.go b/core/state/snapshot/iterator_fast.go new file mode 100644 index 0000000000..afbe70c2bc --- /dev/null +++ b/core/state/snapshot/iterator_fast.go @@ -0,0 +1,302 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "fmt" + "sort" + + "github.com/tomochain/tomochain/common" +) + +// weightedAccountIterator is an account iterator with an assigned weight. It is +// used to prioritise which account is the correct one if multiple iterators find +// the same one (modified in multiple consecutive blocks). +type weightedAccountIterator struct { + it AccountIterator + priority int +} + +// weightedAccountIterators is a set of iterators implementing the sort.Interface. +type weightedAccountIterators []*weightedAccountIterator + +// Len implements sort.Interface, returning the number of active iterators. +func (its weightedAccountIterators) Len() int { return len(its) } + +// Less implements sort.Interface, returning which of two iterators in the stack +// is before the other. +func (its weightedAccountIterators) Less(i, j int) bool { + // Order the iterators primarily by the account hashes + hashI := its[i].it.Hash() + hashJ := its[j].it.Hash() + + switch bytes.Compare(hashI[:], hashJ[:]) { + case -1: + return true + case 1: + return false + } + // Same account in multiple layers, split by priority + return its[i].priority < its[j].priority +} + +// Swap implements sort.Interface, swapping two entries in the iterator stack. +func (its weightedAccountIterators) Swap(i, j int) { + its[i], its[j] = its[j], its[i] +} + +// fastAccountIterator is a more optimized multi-layer iterator which maintains a +// direct mapping of all iterators leading down to the bottom layer. +type fastAccountIterator struct { + tree *Tree // Snapshot tree to reinitialize stale sub-iterators with + root common.Hash // Root hash to reinitialize stale sub-iterators through + curAccount []byte + + iterators weightedAccountIterators + initiated bool + fail error +} + +// newFastAccountIterator creates a new hierarhical account iterator with one +// element per diff layer. The returned combo iterator can be used to walk over +// the entire snapshot diff stack simultaneously. +func newFastAccountIterator(tree *Tree, root common.Hash, seek common.Hash) (AccountIterator, error) { + snap := tree.Snapshot(root) + if snap == nil { + return nil, fmt.Errorf("unknown snapshot: %x", root) + } + fi := &fastAccountIterator{ + tree: tree, + root: root, + } + current := snap.(snapshot) + for depth := 0; current != nil; depth++ { + fi.iterators = append(fi.iterators, &weightedAccountIterator{ + it: current.AccountIterator(seek), + priority: depth, + }) + current = current.Parent() + } + fi.init() + return fi, nil +} + +// init walks over all the iterators and resolves any clashes between them, after +// which it prepares the stack for step-by-step iteration. +func (fi *fastAccountIterator) init() { + // Track which account hashes are iterators positioned on + var positioned = make(map[common.Hash]int) + + // Position all iterators and track how many remain live + for i := 0; i < len(fi.iterators); i++ { + // Retrieve the first element and if it clashes with a previous iterator, + // advance either the current one or the old one. Repeat until nothing is + // clashing any more. + it := fi.iterators[i] + for { + // If the iterator is exhausted, drop it off the end + if !it.it.Next() { + it.it.Release() + last := len(fi.iterators) - 1 + + fi.iterators[i] = fi.iterators[last] + fi.iterators[last] = nil + fi.iterators = fi.iterators[:last] + + i-- + break + } + // The iterator is still alive, check for collisions with previous ones + hash := it.it.Hash() + if other, exist := positioned[hash]; !exist { + positioned[hash] = i + break + } else { + // Iterators collide, one needs to be progressed, use priority to + // determine which. + // + // This whole else-block can be avoided, if we instead + // do an initial priority-sort of the iterators. If we do that, + // then we'll only wind up here if a lower-priority (preferred) iterator + // has the same value, and then we will always just continue. + // However, it costs an extra sort, so it's probably not better + if fi.iterators[other].priority < it.priority { + // The 'it' should be progressed + continue + } else { + // The 'other' should be progressed, swap them + it = fi.iterators[other] + fi.iterators[other], fi.iterators[i] = fi.iterators[i], fi.iterators[other] + continue + } + } + } + } + // Re-sort the entire list + sort.Sort(fi.iterators) + fi.initiated = false +} + +// Next steps the iterator forward one element, returning false if exhausted. +func (fi *fastAccountIterator) Next() bool { + if len(fi.iterators) == 0 { + return false + } + if !fi.initiated { + // Don't forward first time -- we had to 'Next' once in order to + // do the sorting already + fi.initiated = true + fi.curAccount = fi.iterators[0].it.Account() + if innerErr := fi.iterators[0].it.Error(); innerErr != nil { + fi.fail = innerErr + return false + } + if fi.curAccount != nil { + return true + } + // Implicit else: we've hit a nil-account, and need to fall through to the + // loop below to land on something non-nil + } + // If an account is deleted in one of the layers, the key will still be there, + // but the actual value will be nil. However, the iterator should not + // export nil-values (but instead simply omit the key), so we need to loop + // here until we either + // - get a non-nil value, + // - hit an error, + // - or exhaust the iterator + for { + if !fi.next(0) { + return false // exhausted + } + fi.curAccount = fi.iterators[0].it.Account() + if innerErr := fi.iterators[0].it.Error(); innerErr != nil { + fi.fail = innerErr + return false // error + } + if fi.curAccount != nil { + break // non-nil value found + } + } + return true +} + +// next handles the next operation internally and should be invoked when we know +// that two elements in the list may have the same value. +// +// For example, if the iterated hashes become [2,3,5,5,8,9,10], then we should +// invoke next(3), which will call Next on elem 3 (the second '5') and will +// cascade along the list, applying the same operation if needed. +func (fi *fastAccountIterator) next(idx int) bool { + // If this particular iterator got exhausted, remove it and return true (the + // next one is surely not exhausted yet, otherwise it would have been removed + // already). + if it := fi.iterators[idx].it; !it.Next() { + it.Release() + + fi.iterators = append(fi.iterators[:idx], fi.iterators[idx+1:]...) + return len(fi.iterators) > 0 + } + // If there's noone left to cascade into, return + if idx == len(fi.iterators)-1 { + return true + } + // We next-ed the iterator at 'idx', now we may have to re-sort that element + var ( + cur, next = fi.iterators[idx], fi.iterators[idx+1] + curHash, nextHash = cur.it.Hash(), next.it.Hash() + ) + if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 { + // It is still in correct place + return true + } else if diff == 0 && cur.priority < next.priority { + // So still in correct place, but we need to iterate on the next + fi.next(idx + 1) + return true + } + // At this point, the iterator is in the wrong location, but the remaining + // list is sorted. Find out where to move the item. + clash := -1 + index := sort.Search(len(fi.iterators), func(n int) bool { + // The iterator always advances forward, so anything before the old slot + // is known to be behind us, so just skip them altogether. This actually + // is an important clause since the sort order got invalidated. + if n < idx { + return false + } + if n == len(fi.iterators)-1 { + // Can always place an elem last + return true + } + nextHash := fi.iterators[n+1].it.Hash() + if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 { + return true + } else if diff > 0 { + return false + } + // The elem we're placing it next to has the same value, + // so whichever winds up on n+1 will need further iteraton + clash = n + 1 + + return cur.priority < fi.iterators[n+1].priority + }) + fi.move(idx, index) + if clash != -1 { + fi.next(clash) + } + return true +} + +// move advances an iterator to another position in the list. +func (fi *fastAccountIterator) move(index, newpos int) { + elem := fi.iterators[index] + copy(fi.iterators[index:], fi.iterators[index+1:newpos+1]) + fi.iterators[newpos] = elem +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit (e.g. snapshot stack becoming stale). +func (fi *fastAccountIterator) Error() error { + return fi.fail +} + +// Hash returns the current key +func (fi *fastAccountIterator) Hash() common.Hash { + return fi.iterators[0].it.Hash() +} + +// Account returns the current key +func (fi *fastAccountIterator) Account() []byte { + return fi.curAccount +} + +// Release iterates over all the remaining live layer iterators and releases each +// of thme individually. +func (fi *fastAccountIterator) Release() { + for _, it := range fi.iterators { + it.it.Release() + } + fi.iterators = nil +} + +// Debug is a convencience helper during testing +func (fi *fastAccountIterator) Debug() { + for _, it := range fi.iterators { + fmt.Printf("[p=%v v=%v] ", it.priority, it.it.Hash()[0]) + } + fmt.Println() +} diff --git a/core/state/snapshot/iterator_test.go b/core/state/snapshot/iterator_test.go new file mode 100644 index 0000000000..613bd9955d --- /dev/null +++ b/core/state/snapshot/iterator_test.go @@ -0,0 +1,658 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "encoding/binary" + "fmt" + "math/rand" + "testing" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" +) + +// TestAccountIteratorBasics tests some simple single-layer iteration +func TestAccountIteratorBasics(t *testing.T) { + var ( + destructs = make(map[common.Hash]struct{}) + accounts = make(map[common.Hash][]byte) + storage = make(map[common.Hash]map[common.Hash][]byte) + ) + // Fill up a parent + for i := 0; i < 100; i++ { + h := randomHash() + data := randomAccount() + + accounts[h] = data + if rand.Intn(4) == 0 { + destructs[h] = struct{}{} + } + if rand.Intn(2) == 0 { + accStorage := make(map[common.Hash][]byte) + value := make([]byte, 32) + rand.Read(value) + accStorage[randomHash()] = value + storage[h] = accStorage + } + } + // Add some (identical) layers on top + parent := newDiffLayer(emptyLayer(), common.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage)) + it := parent.AccountIterator(common.Hash{}) + verifyIterator(t, 100, it) +} + +type testIterator struct { + values []byte +} + +func newTestIterator(values ...byte) *testIterator { + return &testIterator{values} +} + +func (ti *testIterator) Seek(common.Hash) { + panic("implement me") +} + +func (ti *testIterator) Next() bool { + ti.values = ti.values[1:] + return len(ti.values) > 0 +} + +func (ti *testIterator) Error() error { + return nil +} + +func (ti *testIterator) Hash() common.Hash { + return common.BytesToHash([]byte{ti.values[0]}) +} + +func (ti *testIterator) Account() []byte { + return nil +} + +func (ti *testIterator) Release() {} + +func TestFastIteratorBasics(t *testing.T) { + type testCase struct { + lists [][]byte + expKeys []byte + } + for i, tc := range []testCase{ + {lists: [][]byte{{0, 1, 8}, {1, 2, 8}, {2, 9}, {4}, + {7, 14, 15}, {9, 13, 15, 16}}, + expKeys: []byte{0, 1, 2, 4, 7, 8, 9, 13, 14, 15, 16}}, + {lists: [][]byte{{0, 8}, {1, 2, 8}, {7, 14, 15}, {8, 9}, + {9, 10}, {10, 13, 15, 16}}, + expKeys: []byte{0, 1, 2, 7, 8, 9, 10, 13, 14, 15, 16}}, + } { + var iterators []*weightedAccountIterator + for i, data := range tc.lists { + it := newTestIterator(data...) + iterators = append(iterators, &weightedAccountIterator{it, i}) + + } + fi := &fastAccountIterator{ + iterators: iterators, + initiated: false, + } + count := 0 + for fi.Next() { + if got, exp := fi.Hash()[31], tc.expKeys[count]; exp != got { + t.Errorf("tc %d, [%d]: got %d exp %d", i, count, got, exp) + } + count++ + } + } +} + +func verifyIterator(t *testing.T, expCount int, it AccountIterator) { + t.Helper() + + var ( + count = 0 + last = common.Hash{} + ) + for it.Next() { + hash := it.Hash() + if bytes.Compare(last[:], hash[:]) >= 0 { + t.Errorf("wrong order: %x >= %x", last, hash) + } + if it.Account() == nil { + t.Errorf("iterator returned nil-value for hash %x", hash) + } + count++ + } + if count != expCount { + t.Errorf("iterator count mismatch: have %d, want %d", count, expCount) + } + if err := it.Error(); err != nil { + t.Errorf("iterator failed: %v", err) + } +} + +// TestAccountIteratorTraversal tests some simple multi-layer iteration. +func TestAccountIteratorTraversal(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Stack three diff layers on top with various overlaps + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, + randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, + randomAccountSet("0xbb", "0xdd", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, + randomAccountSet("0xcc", "0xf0", "0xff"), nil) + + // Verify the single and multi-layer iterators + head := snaps.Snapshot(common.HexToHash("0x04")) + + verifyIterator(t, 3, head.(snapshot).AccountIterator(common.Hash{})) + verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator()) + + it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + defer it.Release() + + verifyIterator(t, 7, it) +} + +// TestAccountIteratorTraversalValues tests some multi-layer iteration, where we +// also expect the correct values to show up. +func TestAccountIteratorTraversalValues(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Create a batch of account sets to seed subsequent layers with + var ( + a = make(map[common.Hash][]byte) + b = make(map[common.Hash][]byte) + c = make(map[common.Hash][]byte) + d = make(map[common.Hash][]byte) + e = make(map[common.Hash][]byte) + f = make(map[common.Hash][]byte) + g = make(map[common.Hash][]byte) + h = make(map[common.Hash][]byte) + ) + for i := byte(2); i < 0xff; i++ { + a[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 0, i)) + if i > 20 && i%2 == 0 { + b[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 1, i)) + } + if i%4 == 0 { + c[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 2, i)) + } + if i%7 == 0 { + d[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 3, i)) + } + if i%8 == 0 { + e[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 4, i)) + } + if i > 50 || i < 85 { + f[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 5, i)) + } + if i%64 == 0 { + g[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 6, i)) + } + if i%128 == 0 { + h[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 7, i)) + } + } + // Assemble a stack of snapshots from the account layers + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, a, nil) + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, b, nil) + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, c, nil) + snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), nil, d, nil) + snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), nil, e, nil) + snaps.Update(common.HexToHash("0x07"), common.HexToHash("0x06"), nil, f, nil) + snaps.Update(common.HexToHash("0x08"), common.HexToHash("0x07"), nil, g, nil) + snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), nil, h, nil) + + it, _ := snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{}) + defer it.Release() + + head := snaps.Snapshot(common.HexToHash("0x09")) + for it.Next() { + hash := it.Hash() + want, err := head.AccountRLP(hash) + if err != nil { + t.Fatalf("failed to retrieve expected account: %v", err) + } + if have := it.Account(); !bytes.Equal(want, have) { + t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want) + } + } +} + +// This testcase is notorious, all layers contain the exact same 200 accounts. +func TestAccountIteratorLargeTraversal(t *testing.T) { + // Create a custom account factory to recreate the same addresses + makeAccounts := func(num int) map[common.Hash][]byte { + accounts := make(map[common.Hash][]byte) + for i := 0; i < num; i++ { + h := common.Hash{} + binary.BigEndian.PutUint64(h[:], uint64(i+1)) + accounts[h] = randomAccount() + } + return accounts + } + // Build up a large stack of snapshots + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + for i := 1; i < 128; i++ { + snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil) + } + // Iterate the entire stack and ensure everything is hit only once + head := snaps.Snapshot(common.HexToHash("0x80")) + verifyIterator(t, 200, head.(snapshot).AccountIterator(common.Hash{})) + verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator()) + + it, _ := snaps.AccountIterator(common.HexToHash("0x80"), common.Hash{}) + defer it.Release() + + verifyIterator(t, 200, it) +} + +// TestAccountIteratorFlattening tests what happens when we +// - have a live iterator on child C (parent C1 -> C2 .. CN) +// - flattens C2 all the way into CN +// - continues iterating +func TestAccountIteratorFlattening(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Create a stack of diffs on top + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, + randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, + randomAccountSet("0xbb", "0xdd", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, + randomAccountSet("0xcc", "0xf0", "0xff"), nil) + + // Create an iterator and flatten the data from underneath it + it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + defer it.Release() + + if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil { + t.Fatalf("failed to flatten snapshot stack: %v", err) + } + //verifyIterator(t, 7, it) +} + +func TestAccountIteratorSeek(t *testing.T) { + // Create a snapshot stack with some initial data + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, + randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, + randomAccountSet("0xbb", "0xdd", "0xf0"), nil) + + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, + randomAccountSet("0xcc", "0xf0", "0xff"), nil) + + // Construct various iterators and ensure their tranversal is correct + it, _ := snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xdd")) + defer it.Release() + verifyIterator(t, 3, it) // expected: ee, f0, ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xaa")) + defer it.Release() + verifyIterator(t, 3, it) // expected: ee, f0, ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xff")) + defer it.Release() + verifyIterator(t, 0, it) // expected: nothing + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xbb")) + defer it.Release() + verifyIterator(t, 5, it) // expected: cc, dd, ee, f0, ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xef")) + defer it.Release() + verifyIterator(t, 2, it) // expected: f0, ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xf0")) + defer it.Release() + verifyIterator(t, 1, it) // expected: ff + + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xff")) + defer it.Release() + verifyIterator(t, 0, it) // expected: nothing +} + +// TestIteratorDeletions tests that the iterator behaves correct when there are +// deleted accounts (where the Account() value is nil). The iterator +// should not output any accounts or nil-values for those cases. +func TestIteratorDeletions(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Stack three diff layers on top with various overlaps + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), + nil, randomAccountSet("0x11", "0x22", "0x33"), nil) + + deleted := common.HexToHash("0x22") + destructed := map[common.Hash]struct{}{ + deleted: struct{}{}, + } + snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), + destructed, randomAccountSet("0x11", "0x33"), nil) + + snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), + nil, randomAccountSet("0x33", "0x44", "0x55"), nil) + + // The output should be 11,33,44,55 + it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + // Do a quick check + verifyIterator(t, 4, it) + it.Release() + + // And a more detailed verification that we indeed do not see '0x22' + it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{}) + defer it.Release() + for it.Next() { + hash := it.Hash() + if it.Account() == nil { + t.Errorf("iterator returned nil-value for hash %x", hash) + } + if hash == deleted { + t.Errorf("expected deleted elem %x to not be returned by iterator", deleted) + } + } +} + +// BenchmarkAccountIteratorTraversal is a bit a bit notorious -- all layers contain the +// exact same 200 accounts. That means that we need to process 2000 items, but +// only spit out 200 values eventually. +// +// The value-fetching benchmark is easy on the binary iterator, since it never has to reach +// down at any depth for retrieving the values -- all are on the toppmost layer +// +// BenchmarkAccountIteratorTraversal/binary_iterator_keys-6 2239 483674 ns/op +// BenchmarkAccountIteratorTraversal/binary_iterator_values-6 2403 501810 ns/op +// BenchmarkAccountIteratorTraversal/fast_iterator_keys-6 1923 677966 ns/op +// BenchmarkAccountIteratorTraversal/fast_iterator_values-6 1741 649967 ns/op +func BenchmarkAccountIteratorTraversal(b *testing.B) { + // Create a custom account factory to recreate the same addresses + makeAccounts := func(num int) map[common.Hash][]byte { + accounts := make(map[common.Hash][]byte) + for i := 0; i < num; i++ { + h := common.Hash{} + binary.BigEndian.PutUint64(h[:], uint64(i+1)) + accounts[h] = randomAccount() + } + return accounts + } + // Build up a large stack of snapshots + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + for i := 1; i <= 100; i++ { + snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil) + } + // We call this once before the benchmark, so the creation of + // sorted accountlists are not included in the results. + head := snaps.Snapshot(common.HexToHash("0x65")) + head.(*diffLayer).newBinaryAccountIterator() + + b.Run("binary iterator keys", func(b *testing.B) { + for i := 0; i < b.N; i++ { + got := 0 + it := head.(*diffLayer).newBinaryAccountIterator() + for it.Next() { + got++ + } + if exp := 200; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) + b.Run("binary iterator values", func(b *testing.B) { + for i := 0; i < b.N; i++ { + got := 0 + it := head.(*diffLayer).newBinaryAccountIterator() + for it.Next() { + got++ + head.(*diffLayer).accountRLP(it.Hash(), 0) + } + if exp := 200; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) + b.Run("fast iterator keys", func(b *testing.B) { + for i := 0; i < b.N; i++ { + it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{}) + defer it.Release() + + got := 0 + for it.Next() { + got++ + } + if exp := 200; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) + b.Run("fast iterator values", func(b *testing.B) { + for i := 0; i < b.N; i++ { + it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{}) + defer it.Release() + + got := 0 + for it.Next() { + got++ + it.Account() + } + if exp := 200; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) +} + +// BenchmarkAccountIteratorLargeBaselayer is a pretty realistic benchmark, where +// the baselayer is a lot larger than the upper layer. +// +// This is heavy on the binary iterator, which in most cases will have to +// call recursively 100 times for the majority of the values +// +// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(keys)-6 514 1971999 ns/op +// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(values)-6 61 18997492 ns/op +// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(keys)-6 10000 114385 ns/op +// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(values)-6 4047 296823 ns/op +func BenchmarkAccountIteratorLargeBaselayer(b *testing.B) { + // Create a custom account factory to recreate the same addresses + makeAccounts := func(num int) map[common.Hash][]byte { + accounts := make(map[common.Hash][]byte) + for i := 0; i < num; i++ { + h := common.Hash{} + binary.BigEndian.PutUint64(h[:], uint64(i+1)) + accounts[h] = randomAccount() + } + return accounts + } + // Build up a large stack of snapshots + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, makeAccounts(2000), nil) + for i := 2; i <= 100; i++ { + snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(20), nil) + } + // We call this once before the benchmark, so the creation of + // sorted accountlists are not included in the results. + head := snaps.Snapshot(common.HexToHash("0x65")) + head.(*diffLayer).newBinaryAccountIterator() + + b.Run("binary iterator (keys)", func(b *testing.B) { + for i := 0; i < b.N; i++ { + got := 0 + it := head.(*diffLayer).newBinaryAccountIterator() + for it.Next() { + got++ + } + if exp := 2000; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) + b.Run("binary iterator (values)", func(b *testing.B) { + for i := 0; i < b.N; i++ { + got := 0 + it := head.(*diffLayer).newBinaryAccountIterator() + for it.Next() { + got++ + v := it.Hash() + head.(*diffLayer).accountRLP(v, 0) + } + if exp := 2000; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) + b.Run("fast iterator (keys)", func(b *testing.B) { + for i := 0; i < b.N; i++ { + it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{}) + defer it.Release() + + got := 0 + for it.Next() { + got++ + } + if exp := 2000; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) + b.Run("fast iterator (values)", func(b *testing.B) { + for i := 0; i < b.N; i++ { + it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{}) + defer it.Release() + + got := 0 + for it.Next() { + it.Account() + got++ + } + if exp := 2000; got != exp { + b.Errorf("iterator len wrong, expected %d, got %d", exp, got) + } + } + }) +} + +/* +func BenchmarkBinaryAccountIteration(b *testing.B) { + benchmarkAccountIteration(b, func(snap snapshot) AccountIterator { + return snap.(*diffLayer).newBinaryAccountIterator() + }) +} +func BenchmarkFastAccountIteration(b *testing.B) { + benchmarkAccountIteration(b, newFastAccountIterator) +} +func benchmarkAccountIteration(b *testing.B, iterator func(snap snapshot) AccountIterator) { + // Create a diff stack and randomize the accounts across them + layers := make([]map[common.Hash][]byte, 128) + for i := 0; i < len(layers); i++ { + layers[i] = make(map[common.Hash][]byte) + } + for i := 0; i < b.N; i++ { + depth := rand.Intn(len(layers)) + layers[depth][randomHash()] = randomAccount() + } + stack := snapshot(emptyLayer()) + for _, layer := range layers { + stack = stack.Update(common.Hash{}, layer, nil, nil) + } + // Reset the timers and report all the stats + it := iterator(stack) + b.ResetTimer() + b.ReportAllocs() + for it.Next() { + } +} +*/ diff --git a/core/state/snapshot/journal.go b/core/state/snapshot/journal.go new file mode 100644 index 0000000000..0c0e3a960c --- /dev/null +++ b/core/state/snapshot/journal.go @@ -0,0 +1,243 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/rlp" + "github.com/tomochain/tomochain/trie" +) + +// journalGenerator is a disk layer entry containing the generator progress marker. +type journalGenerator struct { + Wiping bool // Whether the database was in progress of being wiped + Done bool // Whether the generator finished creating the snapshot + Marker []byte + Accounts uint64 + Slots uint64 + Storage uint64 +} + +// journalDestruct is an account deletion entry in a diffLayer's disk journal. +type journalDestruct struct { + Hash common.Hash +} + +// journalAccount is an account entry in a diffLayer's disk journal. +type journalAccount struct { + Hash common.Hash + Blob []byte +} + +// journalStorage is an account's storage map in a diffLayer's disk journal. +type journalStorage struct { + Hash common.Hash + Keys []common.Hash + Vals [][]byte +} + +// loadSnapshot loads a pre-existing state snapshot backed by a key-value store. +func loadSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) (snapshot, error) { + // Retrieve the block number and hash of the snapshot, failing if no snapshot + // is present in the database (or crashed mid-update). + baseRoot := rawdb.ReadSnapshotRoot(diskdb) + if baseRoot == (common.Hash{}) { + return nil, errors.New("missing or corrupted snapshot") + } + base := &diskLayer{ + diskdb: diskdb, + triedb: triedb, + cache: fastcache.New(cache * 1024 * 1024), + root: baseRoot, + } + // Retrieve the journal, it must exist since even for 0 layer it stores whether + // we've already generated the snapshot or are in progress only + journal := rawdb.ReadSnapshotJournal(diskdb) + if len(journal) == 0 { + return nil, errors.New("missing or corrupted snapshot journal") + } + r := rlp.NewStream(bytes.NewReader(journal), 0) + + // Read the snapshot generation progress for the disk layer + var generator journalGenerator + if err := r.Decode(&generator); err != nil { + return nil, fmt.Errorf("failed to load snapshot progress marker: %v", err) + } + // Load all the snapshot diffs from the journal + snapshot, err := loadDiffLayer(base, r) + if err != nil { + return nil, err + } + // Entire snapshot journal loaded, sanity check the head and return + // Journal doesn't exist, don't worry if it's not supposed to + if head := snapshot.Root(); head != root { + return nil, fmt.Errorf("head doesn't match snapshot: have %#x, want %#x", head, root) + } + // Everything loaded correctly, resume any suspended operations + if !generator.Done { + // Whether or not wiping was in progress, load any generator progress too + base.genMarker = generator.Marker + if base.genMarker == nil { + base.genMarker = []byte{} + } + base.genPending = make(chan struct{}) + base.genAbort = make(chan chan *generatorStats) + + var origin uint64 + if len(generator.Marker) >= 8 { + origin = binary.BigEndian.Uint64(generator.Marker) + } + go base.generate(&generatorStats{ + origin: origin, + start: time.Now(), + accounts: generator.Accounts, + slots: generator.Slots, + storage: common.StorageSize(generator.Storage), + }) + } + return snapshot, nil +} + +// loadDiffLayer reads the next sections of a snapshot journal, reconstructing a new +// diff and verifying that it can be linked to the requested parent. +func loadDiffLayer(parent snapshot, r *rlp.Stream) (snapshot, error) { + // Read the next diff journal entry + var root common.Hash + if err := r.Decode(&root); err != nil { + // The first read may fail with EOF, marking the end of the journal + if err == io.EOF { + return parent, nil + } + return nil, fmt.Errorf("load diff root: %v", err) + } + var destructs []journalDestruct + if err := r.Decode(&destructs); err != nil { + return nil, fmt.Errorf("load diff destructs: %v", err) + } + destructSet := make(map[common.Hash]struct{}) + for _, entry := range destructs { + destructSet[entry.Hash] = struct{}{} + } + var accounts []journalAccount + if err := r.Decode(&accounts); err != nil { + return nil, fmt.Errorf("load diff accounts: %v", err) + } + accountData := make(map[common.Hash][]byte) + for _, entry := range accounts { + accountData[entry.Hash] = entry.Blob + } + var storage []journalStorage + if err := r.Decode(&storage); err != nil { + return nil, fmt.Errorf("load diff storage: %v", err) + } + storageData := make(map[common.Hash]map[common.Hash][]byte) + for _, entry := range storage { + slots := make(map[common.Hash][]byte) + for i, key := range entry.Keys { + slots[key] = entry.Vals[i] + } + storageData[entry.Hash] = slots + } + return loadDiffLayer(newDiffLayer(parent, root, destructSet, accountData, storageData), r) +} + +// Journal writes the persistent layer generator stats into a buffer to be stored +// in the database as the snapshot journal. +func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + // If the snapshot is currently being generated, abort it + var stats *generatorStats + if dl.genAbort != nil { + abort := make(chan *generatorStats) + dl.genAbort <- abort + + if stats = <-abort; stats != nil { + stats.Log("Journalling in-progress snapshot", dl.genMarker) + } + } + // Ensure the layer didn't get stale + dl.lock.RLock() + defer dl.lock.RUnlock() + + if dl.stale { + return common.Hash{}, ErrSnapshotStale + } + // Ensure the generator stats is written even if none was ran this cycle + journalProgress(dl.diskdb, dl.genMarker, stats) + + log.Debug("Journalled disk layer", "root", dl.root) + return dl.root, nil +} + +// Journal writes the memory layer contents into a buffer to be stored in the +// database as the snapshot journal. +func (dl *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + // Journal the parent first + base, err := dl.parent.Journal(buffer) + if err != nil { + return common.Hash{}, err + } + // Ensure the layer didn't get stale + dl.lock.RLock() + defer dl.lock.RUnlock() + + if dl.Stale() { + return common.Hash{}, ErrSnapshotStale + } + // Everything below was journalled, persist this layer too + if err := rlp.Encode(buffer, dl.root); err != nil { + return common.Hash{}, err + } + destructs := make([]journalDestruct, 0, len(dl.destructSet)) + for hash := range dl.destructSet { + destructs = append(destructs, journalDestruct{Hash: hash}) + } + if err := rlp.Encode(buffer, destructs); err != nil { + return common.Hash{}, err + } + accounts := make([]journalAccount, 0, len(dl.accountData)) + for hash, blob := range dl.accountData { + accounts = append(accounts, journalAccount{Hash: hash, Blob: blob}) + } + if err := rlp.Encode(buffer, accounts); err != nil { + return common.Hash{}, err + } + storage := make([]journalStorage, 0, len(dl.storageData)) + for hash, slots := range dl.storageData { + keys := make([]common.Hash, 0, len(slots)) + vals := make([][]byte, 0, len(slots)) + for key, val := range slots { + keys = append(keys, key) + vals = append(vals, val) + } + storage = append(storage, journalStorage{Hash: hash, Keys: keys, Vals: vals}) + } + if err := rlp.Encode(buffer, storage); err != nil { + return common.Hash{}, err + } + return base, nil +} diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go new file mode 100644 index 0000000000..82cc1addce --- /dev/null +++ b/core/state/snapshot/snapshot.go @@ -0,0 +1,598 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package snapshot implements a journalled, dynamic state dump. +package snapshot + +import ( + "bytes" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/metrics" + "github.com/tomochain/tomochain/trie" +) + +var ( + snapshotCleanAccountHitMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/hit", nil) + snapshotCleanAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/miss", nil) + snapshotCleanAccountInexMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/inex", nil) + snapshotCleanAccountReadMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/read", nil) + snapshotCleanAccountWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/account/write", nil) + + snapshotCleanStorageHitMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/hit", nil) + snapshotCleanStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/miss", nil) + snapshotCleanStorageInexMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/inex", nil) + snapshotCleanStorageReadMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/read", nil) + snapshotCleanStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/storage/write", nil) + + snapshotDirtyAccountHitMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/hit", nil) + snapshotDirtyAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/miss", nil) + snapshotDirtyAccountInexMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/inex", nil) + snapshotDirtyAccountReadMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/read", nil) + snapshotDirtyAccountWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/account/write", nil) + + snapshotDirtyStorageHitMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/hit", nil) + snapshotDirtyStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/miss", nil) + snapshotDirtyStorageInexMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/inex", nil) + snapshotDirtyStorageReadMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/read", nil) + snapshotDirtyStorageWriteMeter = metrics.NewRegisteredMeter("state/snapshot/dirty/storage/write", nil) + + snapshotDirtyAccountHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/account/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015)) + snapshotDirtyStorageHitDepthHist = metrics.NewRegisteredHistogram("state/snapshot/dirty/storage/hit/depth", nil, metrics.NewExpDecaySample(1028, 0.015)) + + snapshotFlushAccountItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/item", nil) + snapshotFlushAccountSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/account/size", nil) + snapshotFlushStorageItemMeter = metrics.NewRegisteredMeter("state/snapshot/flush/storage/item", nil) + snapshotFlushStorageSizeMeter = metrics.NewRegisteredMeter("state/snapshot/flush/storage/size", nil) + + snapshotBloomIndexTimer = metrics.NewRegisteredResettingTimer("state/snapshot/bloom/index", nil) + snapshotBloomErrorGauge = metrics.NewRegisteredGaugeFloat64("state/snapshot/bloom/error", nil) + + snapshotBloomAccountTrueHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/truehit", nil) + snapshotBloomAccountFalseHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/falsehit", nil) + snapshotBloomAccountMissMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/account/miss", nil) + + snapshotBloomStorageTrueHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/truehit", nil) + snapshotBloomStorageFalseHitMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/falsehit", nil) + snapshotBloomStorageMissMeter = metrics.NewRegisteredMeter("state/snapshot/bloom/storage/miss", nil) + + // ErrSnapshotStale is returned from data accessors if the underlying snapshot + // layer had been invalidated due to the chain progressing forward far enough + // to not maintain the layer's original state. + ErrSnapshotStale = errors.New("snapshot stale") + + // ErrNotCoveredYet is returned from data accessors if the underlying snapshot + // is being generated currently and the requested data item is not yet in the + // range of accounts covered. + ErrNotCoveredYet = errors.New("not covered yet") + + // errSnapshotCycle is returned if a snapshot is attempted to be inserted + // that forms a cycle in the snapshot tree. + errSnapshotCycle = errors.New("snapshot cycle") +) + +// Snapshot represents the functionality supported by a snapshot storage layer. +type Snapshot interface { + // Root returns the root hash for which this snapshot was made. + Root() common.Hash + + // Account directly retrieves the account associated with a particular hash in + // the snapshot slim data format. + Account(hash common.Hash) (*types.SlimAccount, error) + + // AccountRLP directly retrieves the account RLP associated with a particular + // hash in the snapshot slim data format. + AccountRLP(hash common.Hash) ([]byte, error) + + // Storage directly retrieves the storage data associated with a particular hash, + // within a particular account. + Storage(accountHash, storageHash common.Hash) ([]byte, error) +} + +// snapshot is the internal version of the snapshot data layer that supports some +// additional methods compared to the public API. +type snapshot interface { + Snapshot + + // Parent returns the subsequent layer of a snapshot, or nil if the base was + // reached. + // + // Note, the method is an internal helper to avoid type switching between the + // disk and diff layers. There is no locking involved. + Parent() snapshot + + // Update creates a new layer on top of the existing snapshot diff tree with + // the specified data items. + // + // Note, the maps are retained by the method to avoid copying everything. + Update(blockRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer + + // Journal commits an entire diff hierarchy to disk into a single journal entry. + // This is meant to be used during shutdown to persist the snapshot without + // flattening everything down (bad for reorgs). + Journal(buffer *bytes.Buffer) (common.Hash, error) + + // Stale return whether this layer has become stale (was flattened across) or + // if it's still live. + Stale() bool + + // AccountIterator creates an account iterator over an arbitrary layer. + AccountIterator(seek common.Hash) AccountIterator +} + +// SnapshotTree is an Ethereum state snapshot tree. It consists of one persistent +// base layer backed by a key-value store, on top of which arbitrarily many in- +// memory diff layers are topped. The memory diffs can form a tree with branching, +// but the disk layer is singleton and common to all. If a reorg goes deeper than +// the disk layer, everything needs to be deleted. +// +// The goal of a state snapshot is twofold: to allow direct access to account and +// storage data to avoid expensive multi-level trie lookups; and to allow sorted, +// cheap iteration of the account/storage tries for sync aid. +type Tree struct { + diskdb ethdb.KeyValueStore // Persistent database to store the snapshot + triedb *trie.Database // In-memory cache to access the trie through + cache int // Megabytes permitted to use for read caches + layers map[common.Hash]snapshot // Collection of all known layers + lock sync.RWMutex +} + +// New attempts to load an already existing snapshot from a persistent key-value +// store (with a number of memory layers from a journal), ensuring that the head +// of the snapshot matches the expected one. +// +// If the snapshot is missing or inconsistent, the entirety is deleted and will +// be reconstructed from scratch based on the tries in the key-value store, on a +// background thread. +func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash, async bool) *Tree { + // Create a new, empty snapshot tree + snap := &Tree{ + diskdb: diskdb, + triedb: triedb, + cache: cache, + layers: make(map[common.Hash]snapshot), + } + if !async { + defer snap.waitBuild() + } + // Attempt to load a previously persisted snapshot and rebuild one if failed + head, err := loadSnapshot(diskdb, triedb, cache, root) + if err != nil { + log.Warn("Failed to load snapshot, regenerating", "err", err) + snap.Rebuild(root) + return snap + } + // Existing snapshot loaded, seed all the layers + for head != nil { + snap.layers[head.Root()] = head + head = head.Parent() + } + return snap +} + +// waitBuild blocks until the snapshot finishes rebuilding. This method is meant +// to be used by tests to ensure we're testing what we believe we are. +func (t *Tree) waitBuild() { + // Find the rebuild termination channel + var done chan struct{} + + t.lock.RLock() + for _, layer := range t.layers { + if layer, ok := layer.(*diskLayer); ok { + done = layer.genPending + break + } + } + t.lock.RUnlock() + + // Wait until the snapshot is generated + if done != nil { + <-done + } +} + +// Snapshot retrieves a snapshot belonging to the given block root, or nil if no +// snapshot is maintained for that block. +func (t *Tree) Snapshot(blockRoot common.Hash) Snapshot { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.layers[blockRoot] +} + +// Update adds a new snapshot into the tree, if that can be linked to an existing +// old parent. It is disallowed to insert a disk layer (the origin of all). +func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs map[common.Hash]struct{}, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) error { + // Reject noop updates to avoid self-loops in the snapshot tree. This is a + // special case that can only happen for Clique networks where empty blocks + // don't modify the state (0 block subsidy). + // + // Although we could silently ignore this internally, it should be the caller's + // responsibility to avoid even attempting to insert such a snapshot. + if blockRoot == parentRoot { + return errSnapshotCycle + } + // Generate a new snapshot on top of the parent + parent := t.Snapshot(parentRoot).(snapshot) + if parent == nil { + return fmt.Errorf("parent [%#x] snapshot missing", parentRoot) + } + snap := parent.Update(blockRoot, destructs, accounts, storage) + + // Save the new snapshot for later + t.lock.Lock() + defer t.lock.Unlock() + + t.layers[snap.root] = snap + return nil +} + +// Cap traverses downwards the snapshot tree from a head block hash until the +// number of allowed layers are crossed. All layers beyond the permitted number +// are flattened downwards. +func (t *Tree) Cap(root common.Hash, layers int) error { + // Retrieve the head snapshot to cap from + snap := t.Snapshot(root) + if snap == nil { + return fmt.Errorf("snapshot [%#x] missing", root) + } + diff, ok := snap.(*diffLayer) + if !ok { + return fmt.Errorf("snapshot [%#x] is disk layer", root) + } + // Run the internal capping and discard all stale layers + t.lock.Lock() + defer t.lock.Unlock() + + // Flattening the bottom-most diff layer requires special casing since there's + // no child to rewire to the grandparent. In that case we can fake a temporary + // child for the capping and then remove it. + var persisted *diskLayer + + switch layers { + case 0: + // If full commit was requested, flatten the diffs and merge onto disk + diff.lock.RLock() + base := diffToDisk(diff.flatten().(*diffLayer)) + diff.lock.RUnlock() + + // Replace the entire snapshot tree with the flat base + t.layers = map[common.Hash]snapshot{base.root: base} + return nil + + case 1: + // If full flattening was requested, flatten the diffs but only merge if the + // memory limit was reached + var ( + bottom *diffLayer + base *diskLayer + ) + diff.lock.RLock() + bottom = diff.flatten().(*diffLayer) + if bottom.memory >= aggregatorMemoryLimit { + base = diffToDisk(bottom) + } + diff.lock.RUnlock() + + // If all diff layers were removed, replace the entire snapshot tree + if base != nil { + t.layers = map[common.Hash]snapshot{base.root: base} + return nil + } + // Merge the new aggregated layer into the snapshot tree, clean stales below + t.layers[bottom.root] = bottom + + default: + // Many layers requested to be retained, cap normally + persisted = t.cap(diff, layers) + } + // Remove any layer that is stale or links into a stale layer + children := make(map[common.Hash][]common.Hash) + for root, snap := range t.layers { + if diff, ok := snap.(*diffLayer); ok { + parent := diff.parent.Root() + children[parent] = append(children[parent], root) + } + } + var remove func(root common.Hash) + remove = func(root common.Hash) { + delete(t.layers, root) + for _, child := range children[root] { + remove(child) + } + delete(children, root) + } + for root, snap := range t.layers { + if snap.Stale() { + remove(root) + } + } + // If the disk layer was modified, regenerate all the cummulative blooms + if persisted != nil { + var rebloom func(root common.Hash) + rebloom = func(root common.Hash) { + if diff, ok := t.layers[root].(*diffLayer); ok { + diff.rebloom(persisted) + } + for _, child := range children[root] { + rebloom(child) + } + } + rebloom(persisted.root) + } + return nil +} + +// cap traverses downwards the diff tree until the number of allowed layers are +// crossed. All diffs beyond the permitted number are flattened downwards. If the +// layer limit is reached, memory cap is also enforced (but not before). +// +// The method returns the new disk layer if diffs were persistend into it. +func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer { + // Dive until we run out of layers or reach the persistent database + for ; layers > 2; layers-- { + // If we still have diff layers below, continue down + if parent, ok := diff.parent.(*diffLayer); ok { + diff = parent + } else { + // Diff stack too shallow, return without modifications + return nil + } + } + // We're out of layers, flatten anything below, stopping if it's the disk or if + // the memory limit is not yet exceeded. + switch parent := diff.parent.(type) { + case *diskLayer: + return nil + + case *diffLayer: + // Flatten the parent into the grandparent. The flattening internally obtains a + // write lock on grandparent. + flattened := parent.flatten().(*diffLayer) + t.layers[flattened.root] = flattened + + diff.lock.Lock() + defer diff.lock.Unlock() + + diff.parent = flattened + if flattened.memory < aggregatorMemoryLimit { + // Accumulator layer is smaller than the limit, so we can abort, unless + // there's a snapshot being generated currently. In that case, the trie + // will move fron underneath the generator so we **must** merge all the + // partial data down into the snapshot and restart the generation. + if flattened.parent.(*diskLayer).genAbort == nil { + return nil + } + } + default: + panic(fmt.Sprintf("unknown data layer: %T", parent)) + } + // If the bottom-most layer is larger than our memory cap, persist to disk + bottom := diff.parent.(*diffLayer) + + bottom.lock.RLock() + base := diffToDisk(bottom) + bottom.lock.RUnlock() + + t.layers[base.root] = base + diff.parent = base + return base +} + +// diffToDisk merges a bottom-most diff into the persistent disk layer underneath +// it. The method will panic if called onto a non-bottom-most diff layer. +func diffToDisk(bottom *diffLayer) *diskLayer { + var ( + base = bottom.parent.(*diskLayer) + batch = base.diskdb.NewBatch() + stats *generatorStats + ) + // If the disk layer is running a snapshot generator, abort it + if base.genAbort != nil { + abort := make(chan *generatorStats) + base.genAbort <- abort + stats = <-abort + } + // Start by temporarily deleting the current snapshot block marker. This + // ensures that in the case of a crash, the entire snapshot is invalidated. + rawdb.DeleteSnapshotRoot(batch) + + // Mark the original base as stale as we're going to create a new wrapper + base.lock.Lock() + if base.stale { + panic("parent disk layer is stale") // we've committed into the same base from two children, boo + } + base.stale = true + base.lock.Unlock() + + // Destroy all the destructed accounts from the database + for hash := range bottom.destructSet { + // Skip any account not covered yet by the snapshot + if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 { + continue + } + // Remove all storage slots + rawdb.DeleteAccountSnapshot(batch, hash) + base.cache.Set(hash[:], nil) + + it := rawdb.IterateStorageSnapshots(base.diskdb, hash) + for it.Next() { + if key := it.Key(); len(key) == 65 { // TODO(karalabe): Yuck, we should move this into the iterator + batch.Delete(key) + base.cache.Del(key[1:]) + + snapshotFlushStorageItemMeter.Mark(1) + } + } + it.Release() + } + // Push all updated accounts into the database + for hash, data := range bottom.accountData { + // Skip any account not covered yet by the snapshot + if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 { + continue + } + // Push the account to disk + rawdb.WriteAccountSnapshot(batch, hash, data) + base.cache.Set(hash[:], data) + snapshotCleanAccountWriteMeter.Mark(int64(len(data))) + + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Crit("Failed to write account snapshot", "err", err) + } + batch.Reset() + } + snapshotFlushAccountItemMeter.Mark(1) + snapshotFlushAccountSizeMeter.Mark(int64(len(data))) + } + // Push all the storage slots into the database + for accountHash, storage := range bottom.storageData { + // Skip any account not covered yet by the snapshot + if base.genMarker != nil && bytes.Compare(accountHash[:], base.genMarker) > 0 { + continue + } + // Generation might be mid-account, track that case too + midAccount := base.genMarker != nil && bytes.Equal(accountHash[:], base.genMarker[:common.HashLength]) + + for storageHash, data := range storage { + // Skip any slot not covered yet by the snapshot + if midAccount && bytes.Compare(storageHash[:], base.genMarker[common.HashLength:]) > 0 { + continue + } + if len(data) > 0 { + rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, data) + base.cache.Set(append(accountHash[:], storageHash[:]...), data) + snapshotCleanStorageWriteMeter.Mark(int64(len(data))) + } else { + rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash) + base.cache.Set(append(accountHash[:], storageHash[:]...), nil) + } + snapshotFlushStorageItemMeter.Mark(1) + snapshotFlushStorageSizeMeter.Mark(int64(len(data))) + } + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + log.Crit("Failed to write storage snapshot", "err", err) + } + batch.Reset() + } + } + // Update the snapshot block marker and write any remainder data + rawdb.WriteSnapshotRoot(batch, bottom.root) + if err := batch.Write(); err != nil { + log.Crit("Failed to write leftover snapshot", "err", err) + } + res := &diskLayer{ + root: bottom.root, + cache: base.cache, + diskdb: base.diskdb, + triedb: base.triedb, + genMarker: base.genMarker, + genPending: base.genPending, + } + // If snapshot generation hasn't finished yet, port over all the starts and + // continue where the previous round left off. + // + // Note, the `base.genAbort` comparison is not used normally, it's checked + // to allow the tests to play with the marker without triggering this path. + if base.genMarker != nil && base.genAbort != nil { + res.genMarker = base.genMarker + res.genAbort = make(chan chan *generatorStats) + go res.generate(stats) + } + return res +} + +// Journal commits an entire diff hierarchy to disk into a single journal entry. +// This is meant to be used during shutdown to persist the snapshot without +// flattening everything down (bad for reorgs). +// +// The method returns the root hash of the base layer that needs to be persisted +// to disk as a trie too to allow continuing any pending generation op. +func (t *Tree) Journal(root common.Hash) (common.Hash, error) { + // Retrieve the head snapshot to journal from var snap snapshot + snap := t.Snapshot(root) + if snap == nil { + return common.Hash{}, fmt.Errorf("snapshot [%#x] missing", root) + } + // Run the journaling + t.lock.Lock() + defer t.lock.Unlock() + + journal := new(bytes.Buffer) + base, err := snap.(snapshot).Journal(journal) + if err != nil { + return common.Hash{}, err + } + // Store the journal into the database and return + rawdb.WriteSnapshotJournal(t.diskdb, journal.Bytes()) + return base, nil +} + +// Rebuild wipes all available snapshot data from the persistent database and +// discard all caches and diff layers. Afterwards, it starts a new snapshot +// generator with the given root hash. +func (t *Tree) Rebuild(root common.Hash) { + t.lock.Lock() + defer t.lock.Unlock() + + // Iterate over and mark all layers stale + for _, layer := range t.layers { + switch layer := layer.(type) { + case *diskLayer: + // If the base layer is generating, abort it and save + if layer.genAbort != nil { + abort := make(chan *generatorStats) + layer.genAbort <- abort + <-abort + } + // Layer should be inactive now, mark it as stale + layer.lock.Lock() + layer.stale = true + layer.lock.Unlock() + + case *diffLayer: + // If the layer is a simple diff, simply mark as stale + layer.lock.Lock() + atomic.StoreUint32(&layer.stale, 1) + layer.lock.Unlock() + + default: + panic(fmt.Sprintf("unknown layer type: %T", layer)) + } + } + // Start generating a new snapshot from scratch on a backgroung thread. The + // generator will run a wiper first if there's not one running right now. + log.Info("Rebuilding state snapshot") + t.layers = map[common.Hash]snapshot{ + root: generateSnapshot(t.diskdb, t.triedb, t.cache, root), + } +} + +// AccountIterator creates a new account iterator for the specified root hash and +// seeks to a starting account hash. +func (t *Tree) AccountIterator(root common.Hash, seek common.Hash) (AccountIterator, error) { + return newFastAccountIterator(t, root, seek) +} diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go new file mode 100644 index 0000000000..35fe62c839 --- /dev/null +++ b/core/state/snapshot/snapshot_test.go @@ -0,0 +1,350 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "fmt" + "math/big" + "math/rand" + "testing" + + "github.com/VictoriaMetrics/fastcache" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/rlp" +) + +// randomHash generates a random blob of data and returns it as a hash. +func randomHash() common.Hash { + var hash common.Hash + if n, err := rand.Read(hash[:]); n != common.HashLength || err != nil { + panic(err) + } + return hash +} + +// randomAccount generates a random account and returns it RLP encoded. +func randomAccount() []byte { + root := randomHash() + a := types.SlimAccount{ + Balance: big.NewInt(rand.Int63()), + Nonce: rand.Uint64(), + Root: root[:], + CodeHash: emptyCode[:], + } + data, _ := rlp.EncodeToBytes(a) + return data +} + +// randomAccountSet generates a set of random accounts with the given strings as +// the account address hashes. +func randomAccountSet(hashes ...string) map[common.Hash][]byte { + accounts := make(map[common.Hash][]byte) + for _, hash := range hashes { + accounts[common.HexToHash(hash)] = randomAccount() + } + return accounts +} + +// Tests that if a disk layer becomes stale, no active external references will +// be returned with junk data. This version of the test flattens every diff layer +// to check internal corner case around the bottom-most memory accumulator. +func TestDiskLayerExternalInvalidationFullFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Retrieve a reference to the base and commit a diff on top + ref := snaps.Snapshot(base.root) + + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 2 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 2) + } + // Commit the diff layer onto the disk and ensure it's persisted + if err := snaps.Cap(common.HexToHash("0x02"), 0); err != nil { + t.Fatalf("failed to merge diff layer onto disk: %v", err) + } + // Since the base layer was modified, ensure that data retrieval on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 1 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 1) + fmt.Println(snaps.layers) + } +} + +// Tests that if a disk layer becomes stale, no active external references will +// be returned with junk data. This version of the test retains the bottom diff +// layer to check the usual mode of operation where the accumulator is retained. +func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Retrieve a reference to the base and commit two diffs on top + ref := snaps.Snapshot(base.root) + + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 3 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3) + } + // Commit the diff layer onto the disk and ensure it's persisted + defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit) + aggregatorMemoryLimit = 0 + + if err := snaps.Cap(common.HexToHash("0x03"), 2); err != nil { + t.Fatalf("failed to merge diff layer onto disk: %v", err) + } + // Since the base layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 2 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2) + fmt.Println(snaps.layers) + } +} + +// Tests that if a diff layer becomes stale, no active external references will +// be returned with junk data. This version of the test flattens every diff layer +// to check internal corner case around the bottom-most memory accumulator. +func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Commit two diffs on top and retrieve a reference to the bottommost + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 3 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3) + } + ref := snaps.Snapshot(common.HexToHash("0x02")) + + // Flatten the diff layer into the bottom accumulator + if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil { + t.Fatalf("failed to flatten diff layer into accumulator: %v", err) + } + // Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 2 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2) + fmt.Println(snaps.layers) + } +} + +// Tests that if a diff layer becomes stale, no active external references will +// be returned with junk data. This version of the test retains the bottom diff +// layer to check the usual mode of operation where the accumulator is retained. +func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) { + // Create an empty base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // Commit three diffs on top and retrieve a reference to the bottommost + accounts := map[common.Hash][]byte{ + common.HexToHash("0xa1"): randomAccount(), + } + if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), nil, accounts, nil); err != nil { + t.Fatalf("failed to create a diff layer: %v", err) + } + if n := len(snaps.layers); n != 4 { + t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 4) + } + ref := snaps.Snapshot(common.HexToHash("0x02")) + + // Doing a Cap operation with many allowed layers should be a no-op + exp := len(snaps.layers) + if err := snaps.Cap(common.HexToHash("0x04"), 2000); err != nil { + t.Fatalf("failed to flatten diff layer into accumulator: %v", err) + } + if got := len(snaps.layers); got != exp { + t.Errorf("layers modified, got %d exp %d", got, exp) + } + // Flatten the diff layer into the bottom accumulator + if err := snaps.Cap(common.HexToHash("0x04"), 2); err != nil { + t.Fatalf("failed to flatten diff layer into accumulator: %v", err) + } + // Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail + if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale { + t.Errorf("stale reference returned account: %#x (err: %v)", acc, err) + } + if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale { + t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err) + } + if n := len(snaps.layers); n != 3 { + t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 3) + fmt.Println(snaps.layers) + } +} + +// TestPostCapBasicDataAccess tests some functionality regarding capping/flattening. +func TestPostCapBasicDataAccess(t *testing.T) { + // setAccount is a helper to construct a random account entry and assign it to + // an account slot in a snapshot + setAccount := func(accKey string) map[common.Hash][]byte { + return map[common.Hash][]byte{ + common.HexToHash(accKey): randomAccount(), + } + } + // Create a starting base layer and a snapshot tree out of it + base := &diskLayer{ + diskdb: rawdb.NewMemoryDatabase(), + root: common.HexToHash("0x01"), + cache: fastcache.New(1024 * 500), + } + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + base.root: base, + }, + } + // The lowest difflayer + snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil) + snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil) + snaps.Update(common.HexToHash("0xb2"), common.HexToHash("0xa1"), nil, setAccount("0xb2"), nil) + + snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil) + snaps.Update(common.HexToHash("0xb3"), common.HexToHash("0xb2"), nil, setAccount("0xb3"), nil) + + // checkExist verifies if an account exiss in a snapshot + checkExist := func(layer *diffLayer, key string) error { + if data, _ := layer.Account(common.HexToHash(key)); data == nil { + return fmt.Errorf("expected %x to exist, got nil", common.HexToHash(key)) + } + return nil + } + // shouldErr checks that an account access errors as expected + shouldErr := func(layer *diffLayer, key string) error { + if data, err := layer.Account(common.HexToHash(key)); err == nil { + return fmt.Errorf("expected error, got data %x", data) + } + return nil + } + // check basics + snap := snaps.Snapshot(common.HexToHash("0xb3")).(*diffLayer) + + if err := checkExist(snap, "0xa1"); err != nil { + t.Error(err) + } + if err := checkExist(snap, "0xb2"); err != nil { + t.Error(err) + } + if err := checkExist(snap, "0xb3"); err != nil { + t.Error(err) + } + // Cap to a bad root should fail + if err := snaps.Cap(common.HexToHash("0x1337"), 0); err == nil { + t.Errorf("expected error, got none") + } + // Now, merge the a-chain + snaps.Cap(common.HexToHash("0xa3"), 0) + + // At this point, a2 got merged into a1. Thus, a1 is now modified, and as a1 is + // the parent of b2, b2 should no longer be able to iterate into parent. + + // These should still be accessible + if err := checkExist(snap, "0xb2"); err != nil { + t.Error(err) + } + if err := checkExist(snap, "0xb3"); err != nil { + t.Error(err) + } + // But these would need iteration into the modified parent + if err := shouldErr(snap, "0xa1"); err != nil { + t.Error(err) + } + if err := shouldErr(snap, "0xa2"); err != nil { + t.Error(err) + } + if err := shouldErr(snap, "0xa3"); err != nil { + t.Error(err) + } + // Now, merge it again, just for fun. It should now error, since a3 + // is a disk layer + if err := snaps.Cap(common.HexToHash("0xa3"), 0); err == nil { + t.Error("expected error capping the disk layer, got none") + } +} diff --git a/core/state/snapshot/sort.go b/core/state/snapshot/sort.go new file mode 100644 index 0000000000..dc877911a1 --- /dev/null +++ b/core/state/snapshot/sort.go @@ -0,0 +1,36 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "bytes" + + "github.com/tomochain/tomochain/common" +) + +// hashes is a helper to implement sort.Interface. +type hashes []common.Hash + +// Len is the number of elements in the collection. +func (hs hashes) Len() int { return len(hs) } + +// Less reports whether the element with index i should sort before the element +// with index j. +func (hs hashes) Less(i, j int) bool { return bytes.Compare(hs[i][:], hs[j][:]) < 0 } + +// Swap swaps the elements with indexes i and j. +func (hs hashes) Swap(i, j int) { hs[i], hs[j] = hs[j], hs[i] } diff --git a/core/state/state_object.go b/core/state/state_object.go index b03231e23b..b0dc8da14d 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -21,9 +21,12 @@ import ( "fmt" "io" "math/big" + "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" ) @@ -31,23 +34,23 @@ var emptyCodeHash = crypto.Keccak256(nil) type Code []byte -func (self Code) String() string { - return string(self) //strings.Join(Disassemble(self), " ") +func (c Code) String() string { + return string(c) //strings.Join(Disassemble(c), " ") } type Storage map[common.Hash]common.Hash -func (self Storage) String() (str string) { - for key, value := range self { +func (s Storage) String() (str string) { + for key, value := range s { str += fmt.Sprintf("%X : %X\n", key, value) } return } -func (self Storage) Copy() Storage { +func (s Storage) Copy() Storage { cpy := make(Storage) - for key, value := range self { + for key, value := range s { cpy[key] = value } @@ -63,7 +66,7 @@ func (self Storage) Copy() Storage { type stateObject struct { address common.Address addrHash common.Hash // hash of ethereum address of the account - data Account + data types.StateAccount db *StateDB // DB error. @@ -77,17 +80,17 @@ type stateObject struct { trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - cachedStorage Storage // Storage entry cache to avoid duplicate reads - dirtyStorage Storage // Storage entries that need to be flushed to disk + originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction + pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block + dirtyStorage Storage // Storage entries that have been modified in the current transaction execution + fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. // Cache flags. - // When an object is marked suicided it will be delete from the trie + // When an object is marked suicided it will be deleted from the trie // during the "update" phase of the state transition. dirtyCode bool // true if the code was updated suicided bool - touched bool deleted bool - onDirty func(addr common.Address) // Callback method to mark a state object newly dirty } // empty returns whether the account is considered empty. @@ -95,231 +98,329 @@ func (s *stateObject) empty() bool { return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) } -// Account is the Ethereum consensus representation of accounts. -// These objects are stored in the main account trie. -type Account struct { - Nonce uint64 - Balance *big.Int - Root common.Hash // merkle root of the storage trie - CodeHash []byte -} - // newObject creates a state object. -func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *stateObject { +func newObject(db *StateDB, address common.Address, data *types.StateAccount) *stateObject { if data.Balance == nil { data.Balance = new(big.Int) } if data.CodeHash == nil { data.CodeHash = emptyCodeHash } + if data.Root == (common.Hash{}) { + data.Root = emptyRoot + } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: data, - cachedStorage: make(Storage), - dirtyStorage: make(Storage), - onDirty: onDirty, + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: *data, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), } } // EncodeRLP implements rlp.Encoder. -func (c *stateObject) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, c.data) +func (s *stateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, s.data) } // setError remembers the first non-nil error it is called with. -func (self *stateObject) setError(err error) { - if self.dbErr == nil { - self.dbErr = err +func (s *stateObject) setError(err error) { + if s.dbErr == nil { + s.dbErr = err } } -func (self *stateObject) markSuicided() { - self.suicided = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } +func (s *stateObject) markSuicided() { + s.suicided = true } -func (c *stateObject) touch() { - c.db.journal = append(c.db.journal, touchChange{ - account: &c.address, - prev: c.touched, - prevDirty: c.onDirty == nil, +func (s *stateObject) touch() { + s.db.journal.append(touchChange{ + account: &s.address, }) - if c.onDirty != nil { - c.onDirty(c.Address()) - c.onDirty = nil + if s.address == ripemd { + // Explicitly put it in the dirty-cache, which is otherwise generated from + // flattened journals. + s.db.journal.dirty(s.address) } - c.touched = true } -func (c *stateObject) getTrie(db Database) Trie { - if c.trie == nil { +func (s *stateObject) getTrie(db Database) Trie { + if s.trie == nil { var err error - c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root) + s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) if err != nil { - c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{}) - c.setError(fmt.Errorf("can't create storage trie: %v", err)) + s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) + s.setError(fmt.Errorf("can't create storage trie: %v", err)) } } - return c.trie + return s.trie } -func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { - value := common.Hash{} - // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) - if err != nil { - self.setError(err) - return common.Hash{} +// GetState retrieves a value from the account storage trie. +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] } - if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - self.setError(err) - } - value.SetBytes(content) + // If we have a dirty value for this state entry, return it + value, dirty := s.dirtyStorage[key] + if dirty { + return value } - return value + // Otherwise return the entry's original value + return s.GetCommittedState(db, key) } -func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { - value, exists := self.cachedStorage[key] - if exists { +// GetCommittedState retrieves a value from the committed account storage trie. +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] + } + // If we have a pending write or clean cached, return that + if value, pending := s.pendingStorage[key]; pending { return value } - // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) - if err != nil { - self.setError(err) - return common.Hash{} + if value, cached := s.originStorage[key]; cached { + return value } - if len(enc) > 0 { - _, content, _, err := rlp.Split(enc) - if err != nil { - self.setError(err) + // If no live objects are available, attempt to use snapshots + var ( + enc []byte + err error + value common.Hash + ) + if s.db.snap != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now()) + } + // If the object was destructed in *this* block (and potentially resurrected), + // the storage has been cleared out, and we should *not* consult the previous + // snapshot about any storage values. The only possible alternatives are: + // 1) resurrect happened, and new slot values were set -- those should + // have been handles via pendingStorage above. + // 2) we don't have new values, and can deliver empty response back + if _, destructed := s.db.snapDestructs[s.addrHash]; destructed { + return common.Hash{} + } + enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:])) + if len(enc) > 0 { + _, content, _, err := rlp.Split(enc) + if err != nil { + s.setError(err) + } + value.SetBytes(content) } - value.SetBytes(content) } - if (value != common.Hash{}) { - self.cachedStorage[key] = value + // If snapshot unavailable or reading from it failed, load from the database + if s.db.snap == nil || err != nil { + start := time.Now() + val, err := s.getTrie(db).GetStorage(s.address, key.Bytes()) + if metrics.EnabledExpensive { + s.db.StorageReads += time.Since(start) + } + if err != nil { + s.setError(err) + return common.Hash{} + } + value.SetBytes(val) } + s.originStorage[key] = value return value } // SetState updates a value in account storage. -func (self *stateObject) SetState(db Database, key, value common.Hash) { - self.db.journal = append(self.db.journal, storageChange{ - account: &self.address, +func (s *stateObject) SetState(db Database, key, value common.Hash) { + // If the fake storage is set, put the temporary state update here. + if s.fakeStorage != nil { + s.fakeStorage[key] = value + return + } + // If the new value is the same as old, don't set + prev := s.GetState(db, key) + if prev == value { + return + } + // New value is different, update and journal the change + s.db.journal.append(storageChange{ + account: &s.address, key: key, - prevalue: self.GetState(db, key), + prevalue: prev, }) - self.setState(key, value) + s.setState(key, value) } -func (self *stateObject) setState(key, value common.Hash) { - self.cachedStorage[key] = value - self.dirtyStorage[key] = value +// SetStorage replaces the entire state storage with the given one. +// +// After this function is called, all original state will be ignored and state +// lookup only happens in the fake state storage. +// +// Note this function should only be used for debugging purpose. +func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) { + // Allocate fake storage if it's nil. + if s.fakeStorage == nil { + s.fakeStorage = make(Storage) + } + for key, value := range storage { + s.fakeStorage[key] = value + } + // Don't bother journal since this function should only be used for + // debugging and the `fake` storage won't be committed to database. +} + +func (s *stateObject) setState(key, value common.Hash) { + s.dirtyStorage[key] = value +} - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +// finalise moves all dirty storage slots into the pending area to be hashed or +// committed later. It is invoked at the end of every transaction. +func (s *stateObject) finalise() { + for key, value := range s.dirtyStorage { + s.pendingStorage[key] = value + } + if len(s.dirtyStorage) > 0 { + s.dirtyStorage = make(Storage) } } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db Database) Trie { - tr := self.getTrie(db) - for key, value := range self.dirtyStorage { - delete(self.dirtyStorage, key) - if (value == common.Hash{}) { - self.setError(tr.TryDelete(key[:])) +// It will return nil if the trie has not been loaded and no changes have been made +func (s *stateObject) updateTrie(db Database) Trie { + // Make sure all dirty slots are finalized into the pending storage area + s.finalise() + if len(s.pendingStorage) == 0 { + return s.trie + } + // Track the amount of time wasted on updating the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) + } + // Retrieve the snapshot storage map for the object + var storage map[common.Hash][]byte + if s.db.snap != nil { + // Retrieve the old storage map, if available, create a new one otherwise + storage = s.db.snapStorage[s.addrHash] + if storage == nil { + storage = make(map[common.Hash][]byte) + s.db.snapStorage[s.addrHash] = storage + } + } + // Insert all the pending updates into the trie + tr := s.getTrie(db) + for key, value := range s.pendingStorage { + // Skip noop changes, persist actual changes + if value == s.originStorage[key] { continue } - // Encoding []byte cannot fail, ok to ignore the error. - v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - self.setError(tr.TryUpdate(key[:], v)) + s.originStorage[key] = value + + var v []byte + if (value == common.Hash{}) { + s.setError(tr.DeleteStorage(s.address, key.Bytes())) + } else { + // Encoding []byte cannot fail, ok to ignore the error. + v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + s.setError(tr.UpdateStorage(s.address, key.Bytes(), v)) + } + // If state snapshotting is active, cache the data til commit + if storage != nil { + storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00 + } + } + if len(s.pendingStorage) > 0 { + s.pendingStorage = make(Storage) } return tr } // UpdateRoot sets the trie root to the current root hash of -func (self *stateObject) updateRoot(db Database) { - self.updateTrie(db) - self.data.Root = self.trie.Hash() +func (s *stateObject) updateRoot(db Database) { + // If nothing changed, don't bother with hashing anything + if s.updateTrie(db) == nil { + return + } + // Track the amount of time wasted on hashing the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) + } + s.data.Root = s.trie.Hash() } -// CommitTrie the storage trie of the object to dwb. +// CommitTrie the storage trie of the object to db. // This updates the trie root. -func (self *stateObject) CommitTrie(db Database) error { - self.updateTrie(db) - if self.dbErr != nil { - return self.dbErr +func (s *stateObject) CommitTrie(db Database) error { + // If nothing changed, don't bother with hashing anything + if s.updateTrie(db) == nil { + return nil + } + if s.dbErr != nil { + return s.dbErr + } + // Track the amount of time wasted on committing the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } - root, err := self.trie.Commit(nil) + root, err := s.trie.Commit(nil) if err == nil { - self.data.Root = root + s.data.Root = root } return err } // AddBalance removes amount from c's balance. // It is used to add funds to the destination account of a transfer. -func (c *stateObject) AddBalance(amount *big.Int) { +func (s *stateObject) AddBalance(amount *big.Int) { // EIP158: We must check emptiness for the objects such that the account // clearing (0,0,0 objects) can take effect. if amount.Sign() == 0 { - if c.empty() { - c.touch() + if s.empty() { + s.touch() } return } - c.SetBalance(new(big.Int).Add(c.Balance(), amount)) + s.SetBalance(new(big.Int).Add(s.Balance(), amount)) } // SubBalance removes amount from c's balance. // It is used to remove funds from the origin account of a transfer. -func (c *stateObject) SubBalance(amount *big.Int) { +func (s *stateObject) SubBalance(amount *big.Int) { if amount.Sign() == 0 { return } - c.SetBalance(new(big.Int).Sub(c.Balance(), amount)) + s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) } -func (self *stateObject) SetBalance(amount *big.Int) { - self.db.journal = append(self.db.journal, balanceChange{ - account: &self.address, - prev: new(big.Int).Set(self.data.Balance), +func (s *stateObject) SetBalance(amount *big.Int) { + s.db.journal.append(balanceChange{ + account: &s.address, + prev: new(big.Int).Set(s.data.Balance), }) - self.setBalance(amount) + s.setBalance(amount) } -func (self *stateObject) setBalance(amount *big.Int) { - self.data.Balance = amount - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } +func (s *stateObject) setBalance(amount *big.Int) { + s.data.Balance = amount } -// Return the gas back to the origin. Used by the Virtual machine or Closures -func (c *stateObject) ReturnGas(gas *big.Int) {} +// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures +func (s *stateObject) ReturnGas(gas *big.Int) {} -func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { - stateObject := newObject(db, self.address, self.data, onDirty) - if self.trie != nil { - stateObject.trie = db.db.CopyTrie(self.trie) +func (s *stateObject) deepCopy(db *StateDB) *stateObject { + stateObject := newObject(db, s.address, &s.data) + if s.trie != nil { + stateObject.trie = db.db.CopyTrie(s.trie) } - stateObject.code = self.code - stateObject.dirtyStorage = self.dirtyStorage.Copy() - stateObject.cachedStorage = self.dirtyStorage.Copy() - stateObject.suicided = self.suicided - stateObject.dirtyCode = self.dirtyCode - stateObject.deleted = self.deleted + stateObject.code = s.code + stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.originStorage = s.originStorage.Copy() + stateObject.pendingStorage = s.pendingStorage.Copy() + stateObject.suicided = s.suicided + stateObject.dirtyCode = s.dirtyCode + stateObject.deleted = s.deleted return stateObject } @@ -327,78 +428,70 @@ func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address) // Attribute accessors // -// Returns the address of the contract/account -func (c *stateObject) Address() common.Address { - return c.address +// Address returns the address of the contract/account +func (s *stateObject) Address() common.Address { + return s.address } // Code returns the contract code associated with this object, if any. -func (self *stateObject) Code(db Database) []byte { - if self.code != nil { - return self.code +func (s *stateObject) Code(db Database) []byte { + if s.code != nil { + return s.code } - if bytes.Equal(self.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), emptyCodeHash) { return nil } - code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash())) + code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash())) if err != nil { - self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) + s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } - self.code = code + s.code = code return code } -func (self *stateObject) SetCode(codeHash common.Hash, code []byte) { - prevcode := self.Code(self.db.db) - self.db.journal = append(self.db.journal, codeChange{ - account: &self.address, - prevhash: self.CodeHash(), +func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := s.Code(s.db.db) + s.db.journal.append(codeChange{ + account: &s.address, + prevhash: s.CodeHash(), prevcode: prevcode, }) - self.setCode(codeHash, code) + s.setCode(codeHash, code) } -func (self *stateObject) setCode(codeHash common.Hash, code []byte) { - self.code = code - self.data.CodeHash = codeHash[:] - self.dirtyCode = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } +func (s *stateObject) setCode(codeHash common.Hash, code []byte) { + s.code = code + s.data.CodeHash = codeHash[:] + s.dirtyCode = true } -func (self *stateObject) SetNonce(nonce uint64) { - self.db.journal = append(self.db.journal, nonceChange{ - account: &self.address, - prev: self.data.Nonce, +func (s *stateObject) SetNonce(nonce uint64) { + s.db.journal.append(nonceChange{ + account: &s.address, + prev: s.data.Nonce, }) - self.setNonce(nonce) + s.setNonce(nonce) } -func (self *stateObject) setNonce(nonce uint64) { - self.data.Nonce = nonce - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil - } +func (s *stateObject) setNonce(nonce uint64) { + s.data.Nonce = nonce } -func (self *stateObject) CodeHash() []byte { - return self.data.CodeHash +func (s *stateObject) CodeHash() []byte { + return s.data.CodeHash } -func (self *stateObject) Balance() *big.Int { - return self.data.Balance +func (s *stateObject) Balance() *big.Int { + return s.data.Balance } -func (self *stateObject) Nonce() uint64 { - return self.data.Nonce +func (s *stateObject) Nonce() uint64 { + return s.data.Nonce } -// Never called, but must be present to allow stateObject to be used +// Value is never called, but must be present to allow stateObject to be used // as a vm.Account interface that also satisfies the vm.ContractRef // interface. Interfaces are awesome. -func (self *stateObject) Value() *big.Int { +func (s *stateObject) Value() *big.Int { panic("Value on stateObject should never be called") } diff --git a/core/state/state_test.go b/core/state/state_test.go index 30cca6c361..17ecf3b192 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -18,14 +18,16 @@ package state import ( "bytes" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" + checker "gopkg.in/check.v1" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" - checker "gopkg.in/check.v1" + "github.com/tomochain/tomochain/trie" ) type StateSuite struct { @@ -88,8 +90,9 @@ func (s *StateSuite) TestDump(c *checker.C) { } func (s *StateSuite) SetUpTest(c *checker.C) { - s.db= rawdb.NewMemoryDatabase() - s.state, _ = New(common.Hash{}, NewDatabase(s.db)) + s.db = rawdb.NewMemoryDatabase() + tdb := NewDatabaseWithConfig(s.db, &trie.Config{Preimages: true}) + s.state, _ = New(common.Hash{}, tdb, nil) } func (s *StateSuite) TestNull(c *checker.C) { @@ -135,7 +138,7 @@ func (s *StateSuite) TestSnapshotEmpty(c *checker.C) { // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { db := rawdb.NewMemoryDatabase() - state, _ := New(common.Hash{}, NewDatabase(db)) + state, _ := New(common.Hash{}, NewDatabase(db), nil) stateobjaddr0 := toAddr([]byte("so0")) stateobjaddr1 := toAddr([]byte("so1")) @@ -210,24 +213,30 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) { t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) } - if len(so1.cachedStorage) != len(so0.cachedStorage) { - t.Errorf("Storage size mismatch: have %d, want %d", len(so1.cachedStorage), len(so0.cachedStorage)) + if len(so1.dirtyStorage) != len(so0.dirtyStorage) { + t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage)) } - for k, v := range so1.cachedStorage { - if so0.cachedStorage[k] != v { - t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.cachedStorage[k], v) + for k, v := range so1.dirtyStorage { + if so0.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, want %v", k, so0.dirtyStorage[k], v) } } - for k, v := range so0.cachedStorage { - if so1.cachedStorage[k] != v { - t.Errorf("Storage key %x mismatch: have %v, want none.", k, v) + for k, v := range so0.dirtyStorage { + if so1.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v) } } - - if so0.suicided != so1.suicided { - t.Fatalf("suicided mismatch: have %v, want %v", so0.suicided, so1.suicided) + if len(so1.originStorage) != len(so0.originStorage) { + t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage)) + } + for k, v := range so1.originStorage { + if so0.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v) + } } - if so0.deleted != so1.deleted { - t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) + for k, v := range so0.originStorage { + if so1.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v) + } } } diff --git a/core/state/statedb.go b/core/state/statedb.go index 7a3357b3e8..f264ede6da 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -22,11 +22,13 @@ import ( "math/big" "sort" "sync" + "time" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) @@ -37,6 +39,9 @@ type revision struct { } var ( + // emptyRoot is the known root hash of an empty trie. + emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + // emptyState is the known hash of an empty state trie entry. emptyState = crypto.Keccak256Hash(nil) @@ -44,7 +49,7 @@ var ( emptyCode = crypto.Keccak256Hash(nil) ) -// StateDBs within the ethereum protocol are used to store anything +// StateDB within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: // * Contracts @@ -53,6 +58,12 @@ type StateDB struct { db Database trie Trie + snaps *snapshot.Tree + snap snapshot.Snapshot + snapDestructs map[common.Hash]struct{} + snapAccounts map[common.Hash][]byte + snapStorage map[common.Hash]map[common.Hash][]byte + // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*stateObject stateObjectsDirty map[common.Address]struct{} @@ -76,144 +87,179 @@ type StateDB struct { // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. - journal journal + journal *journal validRevisions []revision nextRevisionId int + // Measurements gathered during execution for debugging purposes + AccountReads time.Duration + AccountHashes time.Duration + AccountUpdates time.Duration + AccountCommits time.Duration + StorageReads time.Duration + StorageHashes time.Duration + StorageUpdates time.Duration + StorageCommits time.Duration + SnapshotAccountReads time.Duration + SnapshotStorageReads time.Duration + SnapshotCommits time.Duration + lock sync.Mutex } -func (self *StateDB) SubRefund(gas uint64) { - self.journal = append(self.journal, refundChange{ - prev: self.refund}) - if gas > self.refund { - panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, self.refund)) +func (s *StateDB) SubRefund(gas uint64) { + s.journal.append(refundChange{ + prev: s.refund}) + if gas > s.refund { + panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund)) } - self.refund -= gas + s.refund -= gas } -func (self *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { + stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.GetCommittedState(self.db, hash) + return stateObject.GetCommittedState(s.db, hash) } return common.Hash{} } -// Create a new state from a given trie. -func New(root common.Hash, db Database) (*StateDB, error) { +// New creates a new state from a given trie. +func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err } - return &StateDB{ + sdb := &StateDB{ db: db, trie: tr, + snaps: snaps, stateObjects: make(map[common.Address]*stateObject), stateObjectsDirty: make(map[common.Address]struct{}), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), - }, nil + journal: newJournal(), + } + if sdb.snaps != nil { + sdb.snap = sdb.snaps.Snapshot(root) + } + if sdb.snaps != nil { + if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { + sdb.snapDestructs = make(map[common.Hash]struct{}) + sdb.snapAccounts = make(map[common.Hash][]byte) + sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } + } + return sdb, nil } // setError remembers the first non-nil error it is called with. -func (self *StateDB) setError(err error) { - if self.dbErr == nil { - self.dbErr = err +func (s *StateDB) setError(err error) { + if s.dbErr == nil { + s.dbErr = err } } -func (self *StateDB) Error() error { - return self.dbErr +func (s *StateDB) Error() error { + return s.dbErr } // Reset clears out all ephemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. -func (self *StateDB) Reset(root common.Hash) error { - tr, err := self.db.OpenTrie(root) +func (s *StateDB) Reset(root common.Hash) error { + tr, err := s.db.OpenTrie(root) if err != nil { return err } - self.trie = tr - self.stateObjects = make(map[common.Address]*stateObject) - self.stateObjectsDirty = make(map[common.Address]struct{}) - self.thash = common.Hash{} - self.bhash = common.Hash{} - self.txIndex = 0 - self.logs = make(map[common.Hash][]*types.Log) - self.logSize = 0 - self.preimages = make(map[common.Hash][]byte) - self.clearJournalAndRefund() + s.trie = tr + s.stateObjects = make(map[common.Address]*stateObject) + s.stateObjectsDirty = make(map[common.Address]struct{}) + s.thash = common.Hash{} + s.bhash = common.Hash{} + s.txIndex = 0 + s.logs = make(map[common.Hash][]*types.Log) + s.logSize = 0 + s.preimages = make(map[common.Hash][]byte) + s.clearJournalAndRefund() + + if s.snaps != nil { + s.snapAccounts, s.snapDestructs, s.snapStorage = nil, nil, nil + if s.snap = s.snaps.Snapshot(root); s.snap != nil { + s.snapDestructs = make(map[common.Hash]struct{}) + s.snapAccounts = make(map[common.Hash][]byte) + s.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } + } return nil } -func (self *StateDB) AddLog(log *types.Log) { - self.journal = append(self.journal, addLogChange{txhash: self.thash}) +func (s *StateDB) AddLog(log *types.Log) { + s.journal.append(addLogChange{txhash: s.thash}) - log.TxHash = self.thash - log.BlockHash = self.bhash - log.TxIndex = uint(self.txIndex) - log.Index = self.logSize - self.logs[self.thash] = append(self.logs[self.thash], log) - self.logSize++ + log.TxHash = s.thash + log.BlockHash = s.bhash + log.TxIndex = uint(s.txIndex) + log.Index = s.logSize + s.logs[s.thash] = append(s.logs[s.thash], log) + s.logSize++ } -func (self *StateDB) GetLogs(hash common.Hash) []*types.Log { - return self.logs[hash] +func (s *StateDB) GetLogs(hash common.Hash) []*types.Log { + return s.logs[hash] } -func (self *StateDB) Logs() []*types.Log { +func (s *StateDB) Logs() []*types.Log { var logs []*types.Log - for _, lgs := range self.logs { + for _, lgs := range s.logs { logs = append(logs, lgs...) } return logs } // AddPreimage records a SHA3 preimage seen by the VM. -func (self *StateDB) AddPreimage(hash common.Hash, preimage []byte) { - if _, ok := self.preimages[hash]; !ok { - self.journal = append(self.journal, addPreimageChange{hash: hash}) +func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) { + if _, ok := s.preimages[hash]; !ok { + s.journal.append(addPreimageChange{hash: hash}) pi := make([]byte, len(preimage)) copy(pi, preimage) - self.preimages[hash] = pi + s.preimages[hash] = pi } } // Preimages returns a list of SHA3 preimages that have been submitted. -func (self *StateDB) Preimages() map[common.Hash][]byte { - return self.preimages +func (s *StateDB) Preimages() map[common.Hash][]byte { + return s.preimages } -func (self *StateDB) AddRefund(gas uint64) { - self.journal = append(self.journal, refundChange{prev: self.refund}) - self.refund += gas +func (s *StateDB) AddRefund(gas uint64) { + s.journal.append(refundChange{prev: s.refund}) + s.refund += gas } // Exist reports whether the given account address exists in the state. // Notably this also returns true for suicided accounts. -func (self *StateDB) Exist(addr common.Address) bool { - return self.getStateObject(addr) != nil +func (s *StateDB) Exist(addr common.Address) bool { + return s.getStateObject(addr) != nil } // Empty returns whether the state object is either non-existent // or empty according to the EIP161 specification (balance = nonce = code = 0) -func (self *StateDB) Empty(addr common.Address) bool { - so := self.getStateObject(addr) +func (s *StateDB) Empty(addr common.Address) bool { + so := s.getStateObject(addr) return so == nil || so.empty() } -// Retrieve the balance from the given address or 0 if object not found -func (self *StateDB) GetBalance(addr common.Address) *big.Int { - stateObject := self.getStateObject(addr) +// GetBalance retrieves the balance from the given address or 0 if object not found +func (s *StateDB) GetBalance(addr common.Address) *big.Int { + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Balance() } return common.Big0 } -func (self *StateDB) GetNonce(addr common.Address) uint64 { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetNonce(addr common.Address) uint64 { + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.Nonce() } @@ -221,63 +267,63 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { return 0 } -func (self *StateDB) GetCode(addr common.Address) []byte { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCode(addr common.Address) []byte { + stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.Code(self.db) + return stateObject.Code(s.db) } return nil } -func (self *StateDB) GetCodeSize(addr common.Address) int { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCodeSize(addr common.Address) int { + stateObject := s.getStateObject(addr) if stateObject == nil { return 0 } if stateObject.code != nil { return len(stateObject.code) } - size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) + size, err := s.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash())) if err != nil { - self.setError(err) + s.setError(err) } return size } -func (self *StateDB) GetCodeHash(addr common.Address) common.Hash { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { + stateObject := s.getStateObject(addr) if stateObject == nil { return common.Hash{} } return common.BytesToHash(stateObject.CodeHash()) } -func (self *StateDB) GetState(addr common.Address, bhash common.Hash) common.Hash { - stateObject := self.getStateObject(addr) +func (s *StateDB) GetState(addr common.Address, bhash common.Hash) common.Hash { + stateObject := s.getStateObject(addr) if stateObject != nil { - return stateObject.GetState(self.db, bhash) + return stateObject.GetState(s.db, bhash) } return common.Hash{} } // Database retrieves the low level database supporting the lower level trie ops. -func (self *StateDB) Database() Database { - return self.db +func (s *StateDB) Database() Database { + return s.db } // StorageTrie returns the storage trie of an account. // The return value is a copy and is nil for non-existent accounts. -func (self *StateDB) StorageTrie(addr common.Address) Trie { - stateObject := self.getStateObject(addr) +func (s *StateDB) StorageTrie(addr common.Address) Trie { + stateObject := s.getStateObject(addr) if stateObject == nil { return nil } - cpy := stateObject.deepCopy(self, nil) - return cpy.updateTrie(self.db) + cpy := stateObject.deepCopy(s) + return cpy.updateTrie(s.db) } -func (self *StateDB) HasSuicided(addr common.Address) bool { - stateObject := self.getStateObject(addr) +func (s *StateDB) HasSuicided(addr common.Address) bool { + stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.suicided } @@ -289,46 +335,46 @@ func (self *StateDB) HasSuicided(addr common.Address) bool { */ // AddBalance adds amount to the account associated with addr. -func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.AddBalance(amount) } } // SubBalance subtracts amount from the account associated with addr. -func (self *StateDB) SubBalance(addr common.Address, amount *big.Int) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SubBalance(amount) } } -func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SetBalance(amount) } } -func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetNonce(addr common.Address, nonce uint64) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SetNonce(nonce) } } -func (self *StateDB) SetCode(addr common.Address, code []byte) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetCode(addr common.Address, code []byte) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { stateObject.SetCode(crypto.Keccak256Hash(code), code) } } -func (self *StateDB) SetState(addr common.Address, key, value common.Hash) { - stateObject := self.GetOrNewStateObject(addr) +func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { + stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(self.db, key, value) + stateObject.SetState(s.db, key, value) } } @@ -337,12 +383,12 @@ func (self *StateDB) SetState(addr common.Address, key, value common.Hash) { // // The account's state object is still available until the state is committed, // getStateObject will return a non-nil account after Suicide. -func (self *StateDB) Suicide(addr common.Address) bool { - stateObject := self.getStateObject(addr) +func (s *StateDB) Suicide(addr common.Address) bool { + stateObject := s.getStateObject(addr) if stateObject == nil { return false } - self.journal = append(self.journal, suicideChange{ + s.journal.append(suicideChange{ account: &addr, prev: stateObject.suicided, prevbalance: new(big.Int).Set(stateObject.Balance()), @@ -358,34 +404,43 @@ func (self *StateDB) Suicide(addr common.Address) bool { // // updateStateObject writes the given object to the trie. -func (self *StateDB) updateStateObject(stateObject *stateObject) { +func (s *StateDB) updateStateObject(stateObject *stateObject) { addr := stateObject.Address() - data, err := rlp.EncodeToBytes(stateObject) - if err != nil { - panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) + if err := s.trie.UpdateAccount(addr, &stateObject.data); err != nil { + s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) } - self.setError(self.trie.TryUpdate(addr[:], data)) + + // If state snapshotting is active, cache the data til commit. Note, this + // update mechanism is not symmetric to the deletion, because whereas it is + // enough to track account updates at commit time, deletions need tracking + // at transaction boundary level to ensure we capture state clearing. + if s.snap != nil { + s.snapAccounts[stateObject.addrHash] = types.SlimAccountRLP(stateObject.data) + } + } // deleteStateObject removes the given object from the state trie. -func (self *StateDB) deleteStateObject(stateObject *stateObject) { +func (s *StateDB) deleteStateObject(stateObject *stateObject) { stateObject.deleted = true addr := stateObject.Address() - self.setError(self.trie.TryDelete(addr[:])) + if err := s.trie.DeleteAccount(addr); err != nil { + s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) + } } // DeleteAddress removes the address from the state trie. -func (self *StateDB) DeleteAddress(addr common.Address) { - stateObject := self.getStateObject(addr) +func (s *StateDB) DeleteAddress(addr common.Address) { + stateObject := s.getStateObject(addr) if stateObject != nil && !stateObject.deleted { - self.deleteStateObject(stateObject) + s.deleteStateObject(stateObject) } } // Retrieve a state object given my the address. Returns nil if not found. -func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) { +func (s *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) { // Prefer 'live' objects. - if obj := self.stateObjects[addr]; obj != nil { + if obj := s.stateObjects[addr]; obj != nil { if obj.deleted { return nil } @@ -393,53 +448,95 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje } // Load the object from the database. - enc, err := self.trie.TryGet(addr[:]) - if len(enc) == 0 { - self.setError(err) + data, err := s.trie.GetAccount(addr) + if err != nil { + s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err)) return nil } - var data Account - if err := rlp.DecodeBytes(enc, &data); err != nil { - log.Error("Failed to decode state object", "addr", addr, "err", err) + if data == nil { return nil } // Insert into the live set. - obj := newObject(self, addr, data, self.MarkStateObjectDirty) - self.setStateObject(obj) + obj := newObject(s, addr, data) + s.setStateObject(obj) + return obj +} + +// getDeletedStateObject is similar to getStateObject, but instead of returning +// nil for a deleted state object, it returns the actual object with the deleted +// flag set. This is needed by the state journal to revert to the correct s- +// destructed object instead of wiping all knowledge about the state object. +func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { + // Prefer live objects if any is available + if obj := s.stateObjects[addr]; obj != nil { + return obj + } + // If no live objects are available, attempt to use snapshots + var ( + data *types.StateAccount + err error + ) + if s.snap != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now()) + } + var acc *types.SlimAccount + if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil { + if acc == nil { + return nil + } + data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash + if len(data.CodeHash) == 0 { + data.CodeHash = emptyCodeHash + } + data.Root = common.BytesToHash(acc.Root) + if data.Root == (common.Hash{}) { + data.Root = emptyRoot + } + } + } + // If snapshot unavailable or reading from it failed, load from the database + if s.snap == nil || err != nil { + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now()) + } + data, err = s.trie.GetAccount(addr) + if err != nil { + s.setError(err) + return nil + } + } + // Insert into the live set + obj := newObject(s, addr, data) + s.setStateObject(obj) return obj } -func (self *StateDB) setStateObject(object *stateObject) { - self.stateObjects[object.Address()] = object +func (s *StateDB) setStateObject(object *stateObject) { + s.stateObjects[object.Address()] = object } -// Retrieve a state object or create a new state object if nil. -func (self *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { - stateObject := self.getStateObject(addr) +// GetOrNewStateObject retrieves a state object or create a new state object if nil. +func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { + stateObject := s.getStateObject(addr) if stateObject == nil || stateObject.deleted { - stateObject, _ = self.createObject(addr) + stateObject, _ = s.createObject(addr) } return stateObject } -// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly -// state object cache iteration to find a handful of modified ones. -func (self *StateDB) MarkStateObjectDirty(addr common.Address) { - self.stateObjectsDirty[addr] = struct{}{} -} - // createObject creates a new state object. If there is an existing account with // the given address, it is overwritten and returned as the second return value. -func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { - prev = self.getStateObject(addr) - newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) +func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { + prev = s.getStateObject(addr) + newobj = newObject(s, addr, &types.StateAccount{}) newobj.setNonce(0) // sets the object to dirty if prev == nil { - self.journal = append(self.journal, createObjectChange{account: &addr}) + s.journal.append(createObjectChange{account: &addr}) } else { - self.journal = append(self.journal, resetObjectChange{prev: prev}) + s.journal.append(resetObjectChange{prev: prev}) } - self.setStateObject(newobj) + s.setStateObject(newobj) return newobj, prev } @@ -449,34 +546,43 @@ func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObjec // CreateAccount is called during the EVM CREATE operation. The situation might arise that // a contract does the following: // -// 1. sends funds to sha(account ++ (nonce + 1)) -// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. -func (self *StateDB) CreateAccount(addr common.Address) { - new, prev := self.createObject(addr) +func (s *StateDB) CreateAccount(addr common.Address) { + new, prev := s.createObject(addr) if prev != nil { new.setBalance(prev.data.Balance) } } -func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { - so := db.getStateObject(addr) +func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { + so := s.getStateObject(addr) if so == nil { return nil } + tr := so.getTrie(s.db) + trieIt := tr.NodeIterator(nil) + it := trie.NewIterator(trieIt) - // When iterating over the storage check the cache first - for h, value := range so.cachedStorage { - cb(h, value) - } - - it := trie.NewIterator(so.getTrie(db.db).NodeIterator(nil)) for it.Next() { - // ignore cached values - key := common.BytesToHash(db.trie.GetKey(it.Key)) - if _, ok := so.cachedStorage[key]; !ok { - cb(key, common.BytesToHash(it.Value)) + key := common.BytesToHash(s.trie.GetKey(it.Key)) + if value, dirty := so.dirtyStorage[key]; dirty { + if !cb(key, value) { + return nil + } + continue + } + + if len(it.Value) > 0 { + _, content, _, err := rlp.Split(it.Value) + if err != nil { + return err + } + if !cb(key, common.BytesToHash(content)) { + return nil + } } } return nil @@ -484,81 +590,91 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common // Copy creates a deep, independent copy of the state. // Snapshots of the copied state cannot be applied to the copy. -func (self *StateDB) Copy() *StateDB { - self.lock.Lock() - defer self.lock.Unlock() +func (s *StateDB) Copy() *StateDB { + s.lock.Lock() + defer s.lock.Unlock() // Copy all the basic fields, initialize the memory ones state := &StateDB{ - db: self.db, - trie: self.db.CopyTrie(self.trie), - stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)), - stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), - refund: self.refund, - logs: make(map[common.Hash][]*types.Log, len(self.logs)), - logSize: self.logSize, + db: s.db, + trie: s.db.CopyTrie(s.trie), + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), + refund: s.refund, + logs: make(map[common.Hash][]*types.Log, len(s.logs)), + logSize: s.logSize, preimages: make(map[common.Hash][]byte), + journal: newJournal(), } // Copy the dirty states, logs, and preimages - for addr := range self.stateObjectsDirty { - state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) - state.stateObjectsDirty[addr] = struct{}{} + for addr := range s.journal.dirties { + // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), + // and in the Finalise-method, there is a case where an object is in the journal but not + // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for + // nil + if object, exist := s.stateObjects[addr]; exist { + // Even though the original object is dirty, we are not copying the journal, + // so we need to make sure that any side effect the journal would have caused + // during a commit (or similar op) is already applied to the copy. + state.stateObjects[addr] = object.deepCopy(state) + state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits + } } - for hash, logs := range self.logs { + for hash, logs := range s.logs { state.logs[hash] = make([]*types.Log, len(logs)) copy(state.logs[hash], logs) } - for hash, preimage := range self.preimages { + for hash, preimage := range s.preimages { state.preimages[hash] = preimage } return state } // Snapshot returns an identifier for the current revision of the state. -func (self *StateDB) Snapshot() int { - id := self.nextRevisionId - self.nextRevisionId++ - self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) +func (s *StateDB) Snapshot() int { + id := s.nextRevisionId + s.nextRevisionId++ + s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()}) return id } // RevertToSnapshot reverts all state changes made since the given revision. -func (self *StateDB) RevertToSnapshot(revid int) { +func (s *StateDB) RevertToSnapshot(revid int) { // Find the snapshot in the stack of valid snapshots. - idx := sort.Search(len(self.validRevisions), func(i int) bool { - return self.validRevisions[i].id >= revid + idx := sort.Search(len(s.validRevisions), func(i int) bool { + return s.validRevisions[i].id >= revid }) - if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { + if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid { panic(fmt.Errorf("revision id %v cannot be reverted", revid)) } - snapshot := self.validRevisions[idx].journalIndex - - // Replay the journal to undo changes. - for i := len(self.journal) - 1; i >= snapshot; i-- { - self.journal[i].undo(self) - } - self.journal = self.journal[:snapshot] + snapshot := s.validRevisions[idx].journalIndex - // Remove invalidated snapshots from the stack. - self.validRevisions = self.validRevisions[:idx] + // Replay the journal to undo changes and remove invalidated snapshots + s.journal.revert(s, snapshot) + s.validRevisions = s.validRevisions[:idx] } // GetRefund returns the current value of the refund counter. -func (self *StateDB) GetRefund() uint64 { - return self.refund +func (s *StateDB) GetRefund() uint64 { + return s.refund } // Finalise finalises the state by removing the self destructed objects // and clears the journal as well as the refunds. func (s *StateDB) Finalise(deleteEmptyObjects bool) { - for addr := range s.stateObjectsDirty { - stateObject := s.stateObjects[addr] + for addr := range s.journal.dirties { + stateObject, exist := s.stateObjects[addr] + if !exist { + continue + } + if stateObject.suicided || (deleteEmptyObjects && stateObject.empty()) { s.deleteStateObject(stateObject) } else { stateObject.updateRoot(s.db) s.updateStateObject(stateObject) } + s.stateObjectsDirty[addr] = struct{}{} } // Invalidate journal because reverting across transactions is not allowed. s.clearJournalAndRefund() @@ -574,10 +690,10 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { // Prepare sets the current transaction hash and index and block hash which is // used when the EVM emits new state logs. -func (self *StateDB) Prepare(thash, bhash common.Hash, ti int) { - self.thash = thash - self.bhash = bhash - self.txIndex = ti +func (s *StateDB) Prepare(thash, bhash common.Hash, ti int) { + s.thash = thash + s.bhash = bhash + s.txIndex = ti } // DeleteSuicides flags the suicided objects for deletion so that it @@ -602,7 +718,7 @@ func (s *StateDB) DeleteSuicides() { } func (s *StateDB) clearJournalAndRefund() { - s.journal = nil + s.journal = newJournal() s.validRevisions = s.validRevisions[:0] s.refund = 0 } @@ -611,6 +727,10 @@ func (s *StateDB) clearJournalAndRefund() { func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { defer s.clearJournalAndRefund() + for addr := range s.journal.dirties { + s.stateObjectsDirty[addr] = struct{}{} + } + // Commit objects to the trie. for addr, stateObject := range s.stateObjects { _, isDirty := s.stateObjectsDirty[addr] @@ -636,7 +756,7 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) } // Write trie changes. root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error { - var account Account + var account types.StateAccount if err := rlp.DecodeBytes(leaf, &account); err != nil { return nil } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 865ee073b8..085d8f7c27 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -20,7 +20,6 @@ import ( "bytes" "encoding/binary" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "math/rand" @@ -29,6 +28,8 @@ import ( "testing" "testing/quick" + "github.com/tomochain/tomochain/core/rawdb" + check "gopkg.in/check.v1" "github.com/tomochain/tomochain/common" @@ -40,7 +41,7 @@ import ( func TestUpdateLeaks(t *testing.T) { // Create an empty state database db := rawdb.NewMemoryDatabase() - state, _ := New(common.Hash{}, NewDatabase(db)) + state, _ := New(common.Hash{}, NewDatabase(db), nil) // Update it with some accounts for i := byte(0); i < 255; i++ { @@ -70,8 +71,8 @@ func TestIntermediateLeaks(t *testing.T) { // Create two state databases, one transitioning to the final state, the other final from the beginning transDb := rawdb.NewMemoryDatabase() finalDb := rawdb.NewMemoryDatabase() - transState, _ := New(common.Hash{}, NewDatabase(transDb)) - finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) + transState, _ := New(common.Hash{}, NewDatabase(transDb), nil) + finalState, _ := New(common.Hash{}, NewDatabase(finalDb), nil) modify := func(state *StateDB, addr common.Address, i, tweak byte) { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) @@ -129,7 +130,7 @@ func TestIntermediateLeaks(t *testing.T) { func TestCopy(t *testing.T) { // Create a random state test to copy and modify "independently" db := rawdb.NewMemoryDatabase() - orig, _ := New(common.Hash{}, NewDatabase(db)) + orig, _ := New(common.Hash{}, NewDatabase(db), nil) for i := byte(0); i < 255; i++ { obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) @@ -341,7 +342,7 @@ func (test *snapshotTest) run() bool { // Run all actions and create snapshots. var ( db = rawdb.NewMemoryDatabase() - state, _ = New(common.Hash{}, NewDatabase(db)) + state, _ = New(common.Hash{}, NewDatabase(db), nil) snapshotRevs = make([]int, len(test.snapshots)) sindex = 0 ) @@ -355,7 +356,7 @@ func (test *snapshotTest) run() bool { // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, state.Database()) + checkstate, _ := New(common.Hash{}, state.Database(), nil) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } @@ -415,15 +416,21 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { func (s *StateSuite) TestTouchDelete(c *check.C) { s.state.GetOrNewStateObject(common.Address{}) root, _ := s.state.Commit(false) - s.state.Reset(root) + s.state, _ = New(root, s.state.db, s.state.snaps) snapshot := s.state.Snapshot() s.state.AddBalance(common.Address{}, new(big.Int)) - if len(s.state.stateObjectsDirty) != 1 { + if len(s.state.journal.dirties) != 1 { + c.Fatal("expected one dirty state object") + } + if s.state.journal.dirties[common.Address{}] != 1 { c.Fatal("expected one dirty state object") } s.state.RevertToSnapshot(snapshot) - if len(s.state.stateObjectsDirty) != 0 { + if len(s.state.journal.dirties) != 0 { + c.Fatal("expected no dirty state object") + } + if s.state.journal.dirties[common.Address{}] != 0 { c.Fatal("expected no dirty state object") } } diff --git a/core/state/sync.go b/core/state/sync.go index 95f29b2879..e26281c7db 100644 --- a/core/state/sync.go +++ b/core/state/sync.go @@ -20,6 +20,7 @@ import ( "bytes" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" @@ -29,7 +30,7 @@ import ( func NewStateSync(root common.Hash, database ethdb.KeyValueReader, bloom *trie.SyncBloom) *trie.Sync { var syncer *trie.Sync callback := func(leaf []byte, parent common.Hash) error { - var obj Account + var obj types.StateAccount if err := rlp.Decode(bytes.NewReader(leaf), &obj); err != nil { return err } diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 19fefb6548..69c6491f01 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -41,7 +41,7 @@ type testAccount struct { func makeTestState() (Database, common.Hash, []*testAccount) { // Create an empty state db := NewDatabase(rawdb.NewMemoryDatabase()) - state, _ := New(common.Hash{}, db) + state, _ := New(common.Hash{}, db, nil) // Fill it with some arbitrary data accounts := []*testAccount{} @@ -72,7 +72,7 @@ func makeTestState() (Database, common.Hash, []*testAccount) { // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { // Check root availability and state contents - state, err := New(root, NewDatabase(db)) + state, err := New(root, NewDatabase(db), nil) if err != nil { t.Fatalf("failed to create state trie at %x: %v", root, err) } @@ -113,7 +113,7 @@ func checkStateConsistency(db ethdb.Database, root common.Hash) error { if _, err := db.Get(root.Bytes()); err != nil { return nil // Consider a non existent state consistent. } - state, err := New(root, NewDatabase(db)) + state, err := New(root, NewDatabase(db), nil) if err != nil { return err } diff --git a/core/state_processor.go b/core/state_processor.go index 035c15f2b3..b0aeb04b3d 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -18,9 +18,6 @@ package core import ( "fmt" - - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/log" "math/big" "runtime" "strings" @@ -33,7 +30,9 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/tomox/tradingstate" ) // StateProcessor is a basic Processor, which takes care of transitioning @@ -243,7 +242,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* balanceFee = value } } - msg, err := tx.AsMessage(types.MakeSigner(config, header.Number), balanceFee, header.Number) + msg, err := TransactionToMessage(tx, types.MakeSigner(config, header.Number), balanceFee, header.Number) if err != nil { return nil, 0, err, false } @@ -391,7 +390,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* blockMap[9147453] = "0x3538a544021c07869c16b764424c5987409cba48" blockMap[9147459] = "0xe187cf86c2274b1f16e8225a7da9a75aba4f1f5f" - addrFrom := msg.From().Hex() + addrFrom := msg.From.Hex() currentBlockNumber := header.Number.Int64() if addr, ok := blockMap[currentBlockNumber]; ok { @@ -408,7 +407,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* // End Bypass blacklist address // Apply the transaction to the current state (included in the env) - _, gas, failed, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner) + result, err := ApplyMessage(vmenv, msg, gp, coinbaseOwner) if err != nil { return nil, 0, err, false @@ -420,24 +419,24 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* } else { root = statedb.IntermediateRoot(config.IsEIP158(header.Number)).Bytes() } - *usedGas += gas + *usedGas += result.UsedGas // Create a new receipt for the transaction, storing the intermediate root and gas used by the tx // based on the eip phase, we're passing wether the root touch-delete accounts. - receipt := types.NewReceipt(root, failed, *usedGas) + receipt := types.NewReceipt(root, result.Failed(), *usedGas) receipt.TxHash = tx.Hash() - receipt.GasUsed = gas + receipt.GasUsed = result.UsedGas // if the transaction created a contract, store the creation address in the receipt. - if msg.To() == nil { + if msg.To == nil { receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce()) } // Set the receipt logs and create a bloom for filtering receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) - if balanceFee != nil && failed { - state.PayFeeWithTRC21TxFail(statedb, msg.From(), *tx.To()) + if balanceFee != nil && result.Failed() { + state.PayFeeWithTRC21TxFail(statedb, msg.From, *tx.To()) } - return receipt, gas, err, balanceFee != nil + return receipt, result.UsedGas, err, balanceFee != nil } func ApplySignTransaction(config *params.ChainConfig, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64) (*types.Receipt, uint64, error, bool) { diff --git a/core/state_transition.go b/core/state_transition.go index 9a2b079249..cf95a43a4b 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -18,12 +18,13 @@ package core import ( "errors" + "fmt" "math" "math/big" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" - "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" ) @@ -42,15 +43,17 @@ The state transitioning model does all all the necessary work to work out a vali 3) Create a new state object if the recipient is \0*32 4) Value transfer == If contract creation == - 4a) Attempt to run transaction data - 4b) If valid, use result as code for the new state object + + 4a) Attempt to run transaction data + 4b) If valid, use result as code for the new state object + == end == 5) Run Script section 6) Derive new state root */ type StateTransition struct { gp *GasPool - msg Message + msg *Message gas uint64 gasPrice *big.Int initialGas uint64 @@ -60,20 +63,58 @@ type StateTransition struct { evm *vm.EVM } -// Message represents a message sent to a contract. -type Message interface { - From() common.Address - //FromFrontier() (common.Address, error) - To() *common.Address +// A Message contains the data derived from a single transaction that is relevant to state +// processing. +type Message struct { + To *common.Address + From common.Address + Nonce uint64 + Value *big.Int + GasLimit uint64 + GasPrice *big.Int + Data []byte + PmAddress common.Address + PmPayload []byte + BalanceTokenFee *big.Int + + // When SkipAccountChecks is true, the message nonce is not checked against the + // account nonce in state. It also disables checking that the sender is an EOA. + // This field will be set to true for operations like RPC eth_call. + SkipAccountChecks bool +} + +// message no matter the execution itself is successful or not. +type ExecutionResult struct { + UsedGas uint64 // Total used gas but include the refunded gas + Err error // Any error encountered during the execution(listed in core/vm/errors.go) + ReturnData []byte // Returned data from evm(function result or data supplied with revert opcode) +} - GasPrice() *big.Int - Gas() uint64 - Value() *big.Int +// Unwrap returns the internal evm error which allows us for further +// analysis outside. +func (result *ExecutionResult) Unwrap() error { + return result.Err +} + +// Failed returns the indicator whether the execution is successful or not +func (result *ExecutionResult) Failed() bool { return result.Err != nil } + +// Return is a helper function to help caller distinguish between revert reason +// and function return. Return returns the data after execution if no error occurs. +func (result *ExecutionResult) Return() []byte { + if result.Err != nil { + return nil + } + return common.CopyBytes(result.ReturnData) +} - Nonce() uint64 - CheckNonce() bool - Data() []byte - BalanceTokenFee() *big.Int +// Revert returns the concrete revert reason if the execution is aborted by `REVERT` +// opcode. Note the reason can be nil if no data supplied with revert opcode. +func (result *ExecutionResult) Revert() []byte { + if result.Err != vm.ErrExecutionReverted { + return nil + } + return common.CopyBytes(result.ReturnData) } // IntrinsicGas computes the 'intrinsic gas' for a message with the given data. @@ -96,13 +137,13 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) } // Make sure we don't exceed uint64 for all data combinations if (math.MaxUint64-gas)/params.TxDataNonZeroGas < nz { - return 0, vm.ErrOutOfGas + return 0, ErrGasUintOverflow } gas += nz * params.TxDataNonZeroGas z := uint64(len(data)) - nz if (math.MaxUint64-gas)/params.TxDataZeroGas < z { - return 0, vm.ErrOutOfGas + return 0, ErrGasUintOverflow } gas += z * params.TxDataZeroGas } @@ -110,18 +151,47 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) } // NewStateTransition initialises and returns a new state transition object. -func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition { +func NewStateTransition(evm *vm.EVM, msg *Message, gp *GasPool) *StateTransition { return &StateTransition{ gp: gp, evm: evm, msg: msg, - gasPrice: msg.GasPrice(), - value: msg.Value(), - data: msg.Data(), + gasPrice: msg.GasPrice, + value: msg.Value, + data: msg.Data, state: evm.StateDB, } } +// TransactionToMessage converts a transaction into a Message. +func TransactionToMessage(tx *types.Transaction, s types.Signer, balanceFee *big.Int, number *big.Int) (*Message, error) { + from, err := types.Sender(s, tx) + msg := &Message{ + From: from, + Nonce: tx.Nonce(), + GasLimit: tx.Gas(), + GasPrice: new(big.Int).Set(tx.GasPrice()), + To: tx.To(), + Value: tx.Value(), + Data: tx.Data(), + PmAddress: from, + PmPayload: tx.PmPayload(), + SkipAccountChecks: false, + BalanceTokenFee: balanceFee, + } + if len(msg.PmPayload) >= 20 { + msg.PmAddress = common.BytesToAddress(msg.PmPayload[:20]) // the first 20 bytes of PmPayload is the address of Paymaster contract + } + if balanceFee != nil { + if number.Cmp(common.TIPTRC21Fee) > 0 { + msg.GasPrice = common.TRC21GasPrice + } else { + msg.GasPrice = common.TRC21GasPriceBefore + } + } + return msg, err +} + // ApplyMessage computes the new state by applying the given message // against the old state within the environment. // @@ -129,12 +199,20 @@ func NewStateTransition(evm *vm.EVM, msg Message, gp *GasPool) *StateTransition // the gas used (which includes gas refunds) and an error if it failed. An error always // indicates a core error meaning that the message would always fail for that particular // state and would never be accepted within a block. -func ApplyMessage(evm *vm.EVM, msg Message, gp *GasPool, owner common.Address) ([]byte, uint64, bool, error) { +func ApplyMessage(evm *vm.EVM, msg *Message, gp *GasPool, owner common.Address) (*ExecutionResult, error) { return NewStateTransition(evm, msg, gp).TransitionDb(owner) } func (st *StateTransition) from() vm.AccountRef { - f := st.msg.From() + f := st.msg.From + if !st.state.Exist(f) { + st.state.CreateAccount(f) + } + return vm.AccountRef(f) +} + +func (st *StateTransition) pmAddress() vm.AccountRef { + f := st.msg.PmAddress if !st.state.Exist(f) { st.state.CreateAccount(f) } @@ -142,14 +220,14 @@ func (st *StateTransition) from() vm.AccountRef { } func (st *StateTransition) balanceTokenFee() *big.Int { - return st.msg.BalanceTokenFee() + return st.msg.BalanceTokenFee } func (st *StateTransition) to() vm.AccountRef { if st.msg == nil { return vm.AccountRef{} } - to := st.msg.To() + to := st.msg.To if to == nil { return vm.AccountRef{} // contract creation } @@ -161,22 +239,13 @@ func (st *StateTransition) to() vm.AccountRef { return reference } -func (st *StateTransition) useGas(amount uint64) error { - if st.gas < amount { - return vm.ErrOutOfGas - } - st.gas -= amount - - return nil -} - func (st *StateTransition) buyGas() error { var ( state = st.state balanceTokenFee = st.balanceTokenFee() - from = st.from() + from = st.pmAddress() ) - mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.Gas()), st.gasPrice) + mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.GasLimit), st.gasPrice) if balanceTokenFee == nil { if state.GetBalance(from.Address()).Cmp(mgval) < 0 { return errInsufficientBalanceForGas @@ -184,12 +253,12 @@ func (st *StateTransition) buyGas() error { } else if balanceTokenFee.Cmp(mgval) < 0 { return errInsufficientBalanceForGas } - if err := st.gp.SubGas(st.msg.Gas()); err != nil { + if err := st.gp.SubGas(st.msg.GasLimit); err != nil { return err } - st.gas += st.msg.Gas() + st.gas += st.msg.GasLimit - st.initialGas = st.msg.Gas() + st.initialGas = st.msg.GasLimit if balanceTokenFee == nil { state.SubBalance(from.Address(), mgval) } @@ -197,72 +266,95 @@ func (st *StateTransition) buyGas() error { } func (st *StateTransition) preCheck() error { + // Only check transactions that are not fake msg := st.msg - sender := st.from() - - // Make sure this transaction's nonce is correct - if msg.CheckNonce() { - nonce := st.state.GetNonce(sender.Address()) - if nonce < msg.Nonce() { - return ErrNonceTooHigh - } else if nonce > msg.Nonce() { - return ErrNonceTooLow + if !msg.SkipAccountChecks { + // Make sure this transaction's nonce is correct. + stNonce := st.state.GetNonce(msg.From) + if msgNonce := msg.Nonce; stNonce < msgNonce { + return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooHigh, + msg.From.Hex(), msgNonce, stNonce) + } else if stNonce > msgNonce { + return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooLow, + msg.From.Hex(), msgNonce, stNonce) + } else if stNonce+1 < stNonce { + return fmt.Errorf("%w: address %v, nonce: %d", ErrNonceMax, + msg.From.Hex(), stNonce) + } + // Make sure the sender is an EOA + codeHash := st.state.GetCodeHash(msg.From) + if codeHash != (common.Hash{}) && codeHash != types.EmptyCodeHash { + return fmt.Errorf("%w: address %v, codehash: %s", ErrSenderNoEOA, + msg.From.Hex(), codeHash) } } + return st.buyGas() } // TransitionDb will transition the state by applying the current message and -// returning the result including the the used gas. It returns an error if it -// failed. An error indicates a consensus issue. -func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedGas uint64, failed bool, err error) { - if err = st.preCheck(); err != nil { - return +// returning the evm execution result with following fields. +// +// - used gas: +// total gas used (including gas being refunded) +// - returndata: +// the returned data from evm +// - concrete execution error: +// various **EVM** error which aborts the execution, +// e.g. ErrOutOfGas, ErrExecutionReverted +// +// However if any consensus issue encountered, return the error directly with +// nil evm execution result. +func (st *StateTransition) TransitionDb(owner common.Address) (*ExecutionResult, error) { + // First check this message satisfies all consensus rules before + // applying the message. The rules include these clauses + // + // 1. the nonce of the message caller is correct + // 2. caller has enough balance to cover transaction fee(gaslimit * gasprice) + // 3. the amount of gas required is available in the block + // 4. the purchased gas is enough to cover intrinsic usage + // 5. there is no overflow when calculating intrinsic gas + // 6. caller has enough balance to cover asset transfer for **topmost** call + + // Check clauses 1-3, buy gas if everything is correct + if err := st.preCheck(); err != nil { + return nil, err } msg := st.msg sender := st.from() // err checked in preCheck homestead := st.evm.ChainConfig().IsHomestead(st.evm.BlockNumber) - contractCreation := msg.To() == nil + contractCreation := msg.To == nil - // Pay intrinsic gas + // Check clauses 4-5, substract intrinsic gas if everything is correct gas, err := IntrinsicGas(st.data, contractCreation, homestead) if err != nil { - return nil, 0, false, err + return nil, err } - if err = st.useGas(gas); err != nil { - return nil, 0, false, err + if st.gas < gas { + return nil, ErrIntrinsicGas + } + st.gas -= gas + + // check clause 6 + if msg.Value.Sign() > 0 && !st.evm.CanTransfer(st.state, msg.From, msg.Value) { + return nil, ErrInsufficientFundsForTransfer } var ( - evm = st.evm - // vm errors do not effect consensus and are therefor - // not assigned to err, except for insufficient balance - // error. + ret []byte vmerr error ) // for debugging purpose // TODO: clean it after fixing the issue https://github.com/tomochain/tomochain/issues/401 - var contractAction string nonce := uint64(1) if contractCreation { - ret, _, st.gas, vmerr = evm.Create(sender, st.data, st.gas, st.value) - contractAction = "contract creation" + ret, _, st.gas, vmerr = st.evm.Create(sender, st.data, st.gas, st.value) } else { // Increment the nonce for the next transaction nonce = st.state.GetNonce(sender.Address()) + 1 st.state.SetNonce(sender.Address(), nonce) - ret, st.gas, vmerr = evm.Call(sender, st.to().Address(), st.data, st.gas, st.value) - contractAction = "contract call" - } - if vmerr != nil { - log.Debug("VM returned with error", "action", contractAction, "contract address", st.to().Address(), "gas", st.gas, "gasPrice", st.gasPrice, "nonce", nonce, "err", vmerr) - // The only possible consensus-error would be if there wasn't - // sufficient balance to make the transfer happen. The first - // balance transfer may never fail. - if vmerr == vm.ErrInsufficientBalance { - return nil, 0, false, vmerr - } + ret, st.gas, vmerr = st.evm.Call(sender, st.to().Address(), st.data, st.gas, st.value) } st.refundGas() @@ -274,7 +366,11 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG st.state.AddBalance(st.evm.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice)) } - return ret, st.gasUsed(), vmerr != nil, err + return &ExecutionResult{ + UsedGas: st.gasUsed(), + Err: vmerr, + ReturnData: ret, + }, err } func (st *StateTransition) refundGas() { @@ -287,7 +383,7 @@ func (st *StateTransition) refundGas() { balanceTokenFee := st.balanceTokenFee() if balanceTokenFee == nil { - from := st.from() + from := st.pmAddress() // Return ETH for remaining gas, exchanged at the original rate. remaining := new(big.Int).Mul(new(big.Int).SetUint64(st.gas), st.gasPrice) st.state.AddBalance(from.Address(), remaining) diff --git a/core/token_validator.go b/core/token_validator.go index 485ff05c59..e639e8a0a0 100644 --- a/core/token_validator.go +++ b/core/token_validator.go @@ -17,7 +17,11 @@ package core import ( "fmt" - ethereum "github.com/tomochain/tomochain" + "math/big" + "math/rand" + "strings" + + tomochain "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" @@ -25,9 +29,6 @@ import ( "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/log" - "math/big" - "math/rand" - "strings" ) const ( @@ -38,7 +39,7 @@ const ( // callmsg implements core.Message to allow passing it as a transaction simulator. type callmsg struct { - ethereum.CallMsg + tomochain.CallMsg } func (m callmsg) From() common.Address { return m.CallMsg.From } @@ -52,7 +53,7 @@ func (m callmsg) Data() []byte { return m.CallMsg.Data } func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } type SimulatedBackend interface { - CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) + CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) } // GetTokenAbi return token abi @@ -72,7 +73,7 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA } fakeCaller := common.HexToAddress("0x0000000000000000000000000000000000000001") statedb.SetBalance(fakeCaller, common.BasePrice) - msg := ethereum.CallMsg{To: &contractAddr, Data: input, From: fakeCaller} + msg := tomochain.CallMsg{To: &contractAddr, Data: input, From: fakeCaller} result, err := CallContractWithState(msg, chain, statedb) if err != nil { return nil, err @@ -85,9 +86,9 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA return unpackResult, nil } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. -func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { +func CallContractWithState(call tomochain.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. call.GasPrice = big.NewInt(0) @@ -98,11 +99,19 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, call.Value = new(big.Int) } // Execute the call. - msg := callmsg{call} + msg := &Message{ + To: call.To, + From: call.From, + Value: call.Value, + GasLimit: call.Gas, + GasPrice: call.GasPrice, + Data: call.Data, + SkipAccountChecks: false, + } feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) - if msg.To() != nil { - if value, ok := feeCapacity[*msg.To()]; ok { - msg.CallMsg.BalanceTokenFee = value + if msg.To != nil { + if value, ok := feeCapacity[*msg.To]; ok { + msg.BalanceTokenFee = value } } evmContext := NewEVMContext(msg, chain.CurrentHeader(), chain, nil) @@ -111,11 +120,11 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, vmenv := vm.NewEVM(evmContext, statedb, nil, chain.Config(), vm.Config{}) gaspool := new(GasPool).AddGas(1000000) owner := common.Address{} - rval, _, _, err := NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) + result, err := NewStateTransition(vmenv, msg, gaspool).TransitionDb(owner) if err != nil { return nil, err } - return rval, err + return result.Return(), err } // make sure that balance of token is at slot 0 diff --git a/core/tx_pool.go b/core/tx_pool.go index a1028ccaa3..88418b07cd 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -19,21 +19,22 @@ package core import ( "errors" "fmt" - "github.com/tomochain/tomochain/consensus" "math" "math/big" "sort" "sync" "time" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" "github.com/tomochain/tomochain/params" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -80,13 +81,15 @@ var ( // making the transaction invalid, rather a DOS protection. ErrOversizedData = errors.New("oversized data") - ErrZeroGasPrice = errors.New("zero gas price") - - ErrUnderMinGasPrice = errors.New("under min gas price") + // ErrTxTypeNotSupported is returned if a transaction is not supported in the + // current network configuration. + ErrTxTypeNotSupported = types.ErrTxTypeNotSupported + ErrZeroGasPrice = errors.New("zero gas price") + ErrUnderMinGasPrice = errors.New("under min gas price") ErrDuplicateSpecialTransaction = errors.New("duplicate a special transaction") - - ErrMinDeploySMC = errors.New("smart contract creation cost is under allowance") + ErrMinDeploySMC = errors.New("smart contract creation cost is under allowance") + ErrPmPayloadTooShort = errors.New("the length of the paymaster payload is under 20 bytes minumum") ) var ( @@ -583,113 +586,6 @@ func (pool *TxPool) GetSender(tx *types.Transaction) (common.Address, error) { return from, nil } -// validateTx checks whether a transaction is valid according to the consensus -// rules and adheres to some heuristic limits of the local node (price and size). -func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { - // check if sender is in black list - if tx.From() != nil && common.Blacklist[*tx.From()] { - return fmt.Errorf("Reject transaction with sender in black-list: %v", tx.From().Hex()) - } - // check if receiver is in black list - if tx.To() != nil && common.Blacklist[*tx.To()] { - return fmt.Errorf("Reject transaction with receiver in black-list: %v", tx.To().Hex()) - } - - // Heuristic limit, reject transactions over 32KB to prevent DOS attacks - if tx.Size() > 32*1024 { - return ErrOversizedData - } - // Transactions can't be negative. This may never happen using RLP decoded - // transactions but may occur if you create a transaction using the RPC. - if tx.Value().Sign() < 0 { - return ErrNegativeValue - } - // Ensure the transaction doesn't exceed the current block limit gas. - if pool.currentMaxGas < tx.Gas() { - return ErrGasLimit - } - // Make sure the transaction is signed properly - from, err := types.Sender(pool.signer, tx) - if err != nil { - return ErrInvalidSender - } - // Drop non-local transactions under our own minimal accepted gas price - local = local || pool.locals.contains(from) // account may be local even if the transaction arrived from the network - if !local && pool.gasPrice.Cmp(tx.GasPrice()) > 0 { - if !tx.IsSpecialTransaction() || (pool.IsSigner != nil && !pool.IsSigner(from)) { - return ErrUnderpriced - } - } - // Ensure the transaction adheres to nonce ordering - if pool.currentState.GetNonce(from) > tx.Nonce() { - return ErrNonceTooLow - } - if pool.pendingState.GetNonce(from)+common.LimitThresholdNonceInQueue < tx.Nonce() { - return ErrNonceTooHigh - } - // Transactor should have enough funds to cover the costs - // cost == V + GP * GL - balance := pool.currentState.GetBalance(from) - cost := tx.Cost() - minGasPrice := common.MinGasPrice - feeCapacity := big.NewInt(0) - - if tx.To() != nil { - if value, ok := pool.trc21FeeCapacity[*tx.To()]; ok { - feeCapacity = value - if !state.ValidateTRC21Tx(pool.pendingState.StateDB, from, *tx.To(), tx.Data()) { - return ErrInsufficientFunds - } - cost = tx.TRC21Cost() - minGasPrice = common.TRC21GasPrice - } - } - if new(big.Int).Add(balance, feeCapacity).Cmp(cost) < 0 { - return ErrInsufficientFunds - } - - if tx.To() == nil || (tx.To() != nil && !tx.IsSpecialTransaction()) { - intrGas, err := IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead) - if err != nil { - return err - } - // Exclude check smart contract sign address. - if tx.Gas() < intrGas { - return ErrIntrinsicGas - } - - // Check zero gas price. - if tx.GasPrice().Cmp(new(big.Int).SetInt64(0)) == 0 { - return ErrZeroGasPrice - } - - // under min gas price - if tx.GasPrice().Cmp(minGasPrice) < 0 { - return ErrUnderMinGasPrice - } - } - - /* - minGasDeploySMC := new(big.Int).Mul(new(big.Int).SetUint64(10), new(big.Int).SetUint64(params.Ether)) - if tx.To() == nil && (tx.Cost().Cmp(minGasDeploySMC) < 0 || tx.GasPrice().Cmp(new(big.Int).SetUint64(10000*params.Shannon)) < 0) { - return ErrMinDeploySMC - } - */ - - // validate minFee slot for TomoZ - if tx.IsTomoZApplyTransaction() { - copyState := pool.currentState.Copy() - return ValidateTomoZApplyTransaction(pool.chain, nil, copyState, common.BytesToAddress(tx.Data()[4:])) - } - - // validate balance slot, token decimal for TomoX - if tx.IsTomoXApplyTransaction() { - copyState := pool.currentState.Copy() - return ValidateTomoXApplyTransaction(pool.chain, nil, copyState, common.BytesToAddress(tx.Data()[4:])) - } - return nil -} - // add validates a transaction and inserts it into the non-executable queue for // later pending promotion and execution. If the transaction is a replacement for // an already pending or queued one, it overwrites the previous and returns this @@ -706,8 +602,14 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { return false, fmt.Errorf("known transaction: %x", hash) } + opts := &ValidationOptions{ + Config: pool.chainconfig, + Accept: 0 | + 1< 32*1024 { + return ErrOversizedData + } + // Transactions can't be negative. This may never happen using RLP decoded + // transactions but may occur if you create a transaction using the RPC. + if tx.Value().Sign() < 0 { + return ErrNegativeValue + } + // Ensure the transaction doesn't exceed the current block limit gas. + if pool.currentMaxGas < tx.Gas() { + return ErrGasLimit + } + // Make sure the transaction is signed properly + from, err := types.Sender(pool.signer, tx) + if err != nil { + return ErrInvalidSender + } + // Drop non-local transactions under our own minimal accepted gas price + local = local || pool.locals.contains(from) // account may be local even if the transaction arrived from the network + if !local && pool.gasPrice.Cmp(tx.GasPrice()) > 0 { + if !tx.IsSpecialTransaction() || (pool.IsSigner != nil && !pool.IsSigner(from)) { + return ErrUnderpriced + } + } + // Ensure the transaction adheres to nonce ordering + if pool.currentState.GetNonce(from) > tx.Nonce() { + return ErrNonceTooLow + } + if pool.pendingState.GetNonce(from)+common.LimitThresholdNonceInQueue < tx.Nonce() { + return ErrNonceTooHigh + } + // Transactor should have enough funds to cover the costs + // cost == V + GP * GL + balance := pool.currentState.GetBalance(from) + cost := tx.Cost() + minGasPrice := common.MinGasPrice + feeCapacity := big.NewInt(0) + + if tx.To() != nil { + if value, ok := pool.trc21FeeCapacity[*tx.To()]; ok { + feeCapacity = value + if !state.ValidateTRC21Tx(pool.pendingState.StateDB, from, *tx.To(), tx.Data()) { + return ErrInsufficientFunds + } + cost = tx.TRC21Cost() + minGasPrice = common.TRC21GasPrice + } + } + if new(big.Int).Add(balance, feeCapacity).Cmp(cost) < 0 { + return ErrInsufficientFunds + } + + if tx.To() == nil || (tx.To() != nil && !tx.IsSpecialTransaction()) { + intrGas, err := IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead) + if err != nil { + return err + } + // Exclude check smart contract sign address. + if tx.Gas() < intrGas { + return ErrIntrinsicGas + } + + // Check zero gas price. + if tx.GasPrice().Cmp(new(big.Int).SetInt64(0)) == 0 { + return ErrZeroGasPrice + } + + // under min gas price + if tx.GasPrice().Cmp(minGasPrice) < 0 { + return ErrUnderMinGasPrice + } + } + + /* + minGasDeploySMC := new(big.Int).Mul(new(big.Int).SetUint64(10), new(big.Int).SetUint64(params.Ether)) + if tx.To() == nil && (tx.Cost().Cmp(minGasDeploySMC) < 0 || tx.GasPrice().Cmp(new(big.Int).SetUint64(10000*params.Shannon)) < 0) { + return ErrMinDeploySMC + } + */ + + // validate minFee slot for TomoZ + if tx.IsTomoZApplyTransaction() { + copyState := pool.currentState.Copy() + return ValidateTomoZApplyTransaction(pool.chain, nil, copyState, common.BytesToAddress(tx.Data()[4:])) + } + + // validate balance slot, token decimal for TomoX + if tx.IsTomoXApplyTransaction() { + copyState := pool.currentState.Copy() + return ValidateTomoXApplyTransaction(pool.chain, nil, copyState, common.BytesToAddress(tx.Data()[4:])) + } + + // validate the length of paymaster payload + if tx.Type() == types.PaymasterTxType && len(tx.PmPayload()) < 20 { + return ErrPmPayloadTooShort + } + return nil +} diff --git a/core/types/block.go b/core/types/block.go index a055ced147..66baecf2cf 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -33,11 +33,6 @@ import ( "github.com/tomochain/tomochain/rlp" ) -var ( - EmptyRootHash = DeriveSha(Transactions{}) - EmptyUncleHash = CalcUncleHash(nil) -) - // A BlockNonce is a 64-bit hash which proves (combined with the // mix-hash) that a sufficient amount of computation has been carried // out on a block. @@ -225,14 +220,14 @@ type storageblock struct { // The values of TxHash, UncleHash, ReceiptHash and Bloom in header // are ignored and set to values derived from the given txs, uncles // and receipts. -func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt) *Block { +func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []*Receipt, hasher Hasher) *Block { b := &Block{header: CopyHeader(header), td: new(big.Int)} // TODO: panic if len(txs) != len(receipts) if len(txs) == 0 { b.header.TxHash = EmptyRootHash } else { - b.header.TxHash = DeriveSha(Transactions(txs)) + b.header.TxHash = DeriveSha(Transactions(txs), hasher) b.transactions = make(Transactions, len(txs)) copy(b.transactions, txs) } @@ -240,7 +235,7 @@ func NewBlock(header *Header, txs []*Transaction, uncles []*Header, receipts []* if len(receipts) == 0 { b.header.ReceiptHash = EmptyRootHash } else { - b.header.ReceiptHash = DeriveSha(Receipts(receipts)) + b.header.ReceiptHash = DeriveSha(Receipts(receipts), hasher) b.header.Bloom = CreateBloom(receipts) } diff --git a/core/types/block_test.go b/core/types/block_test.go index 9b78b653c7..e93ae02de8 100644 --- a/core/types/block_test.go +++ b/core/types/block_test.go @@ -17,13 +17,15 @@ package types import ( + "bytes" + "hash" "math/big" + "reflect" "testing" - "bytes" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/rlp" - "reflect" + "golang.org/x/crypto/sha3" ) // from bcValidBlockTest.json, "SimpleTx" @@ -59,3 +61,38 @@ func TestBlockEncoding(t *testing.T) { t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc) } } + +func TestUncleHash(t *testing.T) { + uncles := make([]*Header, 0) + h := CalcUncleHash(uncles) + exp := common.HexToHash("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") + if h != exp { + t.Fatalf("empty uncle hash is wrong, got %x != %x", h, exp) + } +} + +var benchBuffer = bytes.NewBuffer(make([]byte, 0, 32000)) + +// testHasher is the helper tool for transaction/receipt list hashing. +// The original hasher is trie, in order to get rid of import cycle, +// use the testing hasher instead. +type testHasher struct { + hasher hash.Hash +} + +func newHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +func (h *testHasher) Update(key, val []byte) { + h.hasher.Write(key) + h.hasher.Write(val) +} + +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 2731c39cbb..210ee26e06 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -21,21 +21,58 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/rlp" - "github.com/tomochain/tomochain/trie" ) +// DerivableList is the interface which can derive the hash. type DerivableList interface { Len() int - GetRlp(i int) []byte + EncodeIndex(int, *bytes.Buffer) } -func DeriveSha(list DerivableList) common.Hash { - keybuf := new(bytes.Buffer) - trie := new(trie.Trie) - for i := 0; i < list.Len(); i++ { - keybuf.Reset() - rlp.Encode(keybuf, uint(i)) - trie.Update(keybuf.Bytes(), list.GetRlp(i)) +// Hasher is the tool used to calculate the hash of derivable list. +type Hasher interface { + Reset() + Update([]byte, []byte) error + Hash() common.Hash +} + +func encodeForDerive(list DerivableList, i int, buf *bytes.Buffer) []byte { + buf.Reset() + list.EncodeIndex(i, buf) + // It's really unfortunate that we need to do perform this copy. + // StackTrie holds onto the values until Hash is called, so the values + // written to it must not alias. + return common.CopyBytes(buf.Bytes()) +} + +// DeriveSha creates the tree hashes of transactions, receipts, and withdrawals in a block header. +func DeriveSha(list DerivableList, hasher Hasher) common.Hash { + hasher.Reset() + + valueBuf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(valueBuf) + + // StackTrie requires values to be inserted in increasing hash order, which is not the + // order that `list` provides hashes in. This insertion sequence ensures that the + // order is correct. + // + // The error returned by hasher is omitted because hasher will produce an incorrect + // hash in case any error occurs. + var indexBuf []byte + for i := 1; i < list.Len() && i <= 0x7f; i++ { + indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) + value := encodeForDerive(list, i, valueBuf) + hasher.Update(indexBuf, value) + } + if list.Len() > 0 { + indexBuf = rlp.AppendUint64(indexBuf[:0], 0) + value := encodeForDerive(list, 0, valueBuf) + hasher.Update(indexBuf, value) + } + for i := 0x80; i < list.Len(); i++ { + indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) + value := encodeForDerive(list, i, valueBuf) + hasher.Update(indexBuf, value) } - return trie.Hash() + return hasher.Hash() } diff --git a/core/types/gen_header_rlp.go b/core/types/gen_header_rlp.go new file mode 100644 index 0000000000..1422cf6b16 --- /dev/null +++ b/core/types/gen_header_rlp.go @@ -0,0 +1,58 @@ +// Code generated by rlpgen. DO NOT EDIT. + +//go:build !norlpgen +// +build !norlpgen + +package types + +import ( + "io" + + "github.com/tomochain/tomochain/rlp" +) + +func (obj *Header) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteBytes(obj.ParentHash[:]) + w.WriteBytes(obj.UncleHash[:]) + w.WriteBytes(obj.Coinbase[:]) + w.WriteBytes(obj.Root[:]) + w.WriteBytes(obj.TxHash[:]) + w.WriteBytes(obj.ReceiptHash[:]) + w.WriteBytes(obj.Bloom[:]) + if obj.Difficulty == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Difficulty.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Difficulty) + } + if obj.Number == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Number.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Number) + } + w.WriteUint64(obj.GasLimit) + w.WriteUint64(obj.GasUsed) + if obj.Time == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Time.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Time) + } + w.WriteBytes(obj.Extra) + w.WriteBytes(obj.MixDigest[:]) + w.WriteBytes(obj.Nonce[:]) + w.WriteBytes(obj.Validators) + w.WriteBytes(obj.Validator) + w.WriteBytes(obj.Penalties) + w.ListEnd(_tmp0) + return w.Flush() +} diff --git a/core/types/gen_log_json.go b/core/types/gen_log_json.go index 759ff8814c..ae61caf6b9 100644 --- a/core/types/gen_log_json.go +++ b/core/types/gen_log_json.go @@ -12,6 +12,7 @@ import ( var _ = (*logMarshaling)(nil) +// MarshalJSON marshals as JSON. func (l Log) MarshalJSON() ([]byte, error) { type Log struct { Address common.Address `json:"address" gencodec:"required"` @@ -37,6 +38,7 @@ func (l Log) MarshalJSON() ([]byte, error) { return json.Marshal(&enc) } +// UnmarshalJSON unmarshals from JSON. func (l *Log) UnmarshalJSON(input []byte) error { type Log struct { Address *common.Address `json:"address" gencodec:"required"` diff --git a/core/types/gen_receipt_json.go b/core/types/gen_receipt_json.go index ffc851f2db..8e502c7bc1 100644 --- a/core/types/gen_receipt_json.go +++ b/core/types/gen_receipt_json.go @@ -5,6 +5,7 @@ package types import ( "encoding/json" "errors" + "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" @@ -12,49 +13,66 @@ import ( var _ = (*receiptMarshaling)(nil) +// MarshalJSON marshals as JSON. func (r Receipt) MarshalJSON() ([]byte, error) { type Receipt struct { + Type hexutil.Uint64 `json:"type,omitempty"` PostState hexutil.Bytes `json:"root"` - Status hexutil.Uint `json:"status"` + Status hexutil.Uint64 `json:"status"` CumulativeGasUsed hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom Bloom `json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` TxHash common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress common.Address `json:"contractAddress"` GasUsed hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + BlockHash common.Hash `json:"blockHash,omitempty"` + BlockNumber *hexutil.Big `json:"blockNumber,omitempty"` + TransactionIndex hexutil.Uint `json:"transactionIndex"` } var enc Receipt + enc.Type = hexutil.Uint64(r.Type) enc.PostState = r.PostState - enc.Status = hexutil.Uint(r.Status) + enc.Status = hexutil.Uint64(r.Status) enc.CumulativeGasUsed = hexutil.Uint64(r.CumulativeGasUsed) enc.Bloom = r.Bloom enc.Logs = r.Logs enc.TxHash = r.TxHash enc.ContractAddress = r.ContractAddress enc.GasUsed = hexutil.Uint64(r.GasUsed) + enc.BlockHash = r.BlockHash + enc.BlockNumber = (*hexutil.Big)(r.BlockNumber) + enc.TransactionIndex = hexutil.Uint(r.TransactionIndex) return json.Marshal(&enc) } +// UnmarshalJSON unmarshals from JSON. func (r *Receipt) UnmarshalJSON(input []byte) error { type Receipt struct { + Type *hexutil.Uint64 `json:"type,omitempty"` PostState *hexutil.Bytes `json:"root"` - Status *hexutil.Uint `json:"status"` + Status *hexutil.Uint64 `json:"status"` CumulativeGasUsed *hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom *Bloom `json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` TxHash *common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress *common.Address `json:"contractAddress"` GasUsed *hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + BlockHash *common.Hash `json:"blockHash,omitempty"` + BlockNumber *hexutil.Big `json:"blockNumber,omitempty"` + TransactionIndex *hexutil.Uint `json:"transactionIndex"` } var dec Receipt if err := json.Unmarshal(input, &dec); err != nil { return err } + if dec.Type != nil { + r.Type = uint8(*dec.Type) + } if dec.PostState != nil { r.PostState = *dec.PostState } if dec.Status != nil { - r.Status = uint(*dec.Status) + r.Status = uint64(*dec.Status) } if dec.CumulativeGasUsed == nil { return errors.New("missing required field 'cumulativeGasUsed' for Receipt") @@ -79,5 +97,14 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { return errors.New("missing required field 'gasUsed' for Receipt") } r.GasUsed = uint64(*dec.GasUsed) + if dec.BlockHash != nil { + r.BlockHash = *dec.BlockHash + } + if dec.BlockNumber != nil { + r.BlockNumber = (*big.Int)(dec.BlockNumber) + } + if dec.TransactionIndex != nil { + r.TransactionIndex = uint(*dec.TransactionIndex) + } return nil } diff --git a/core/types/gen_tx_json.go b/core/types/gen_tx_json.go deleted file mode 100644 index f43cb04e57..0000000000 --- a/core/types/gen_tx_json.go +++ /dev/null @@ -1,99 +0,0 @@ -// Code generated by github.com/fjl/gencodec. DO NOT EDIT. - -package types - -import ( - "encoding/json" - "errors" - "math/big" - - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/common/hexutil" -) - -var _ = (*txdataMarshaling)(nil) - -func (t txdata) MarshalJSON() ([]byte, error) { - type txdata struct { - AccountNonce hexutil.Uint64 `json:"nonce" gencodec:"required"` - Price *hexutil.Big `json:"gasPrice" gencodec:"required"` - GasLimit hexutil.Uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` - Amount *hexutil.Big `json:"value" gencodec:"required"` - Payload hexutil.Bytes `json:"input" gencodec:"required"` - V *hexutil.Big `json:"v" gencodec:"required"` - R *hexutil.Big `json:"r" gencodec:"required"` - S *hexutil.Big `json:"s" gencodec:"required"` - Hash *common.Hash `json:"hash" rlp:"-"` - } - var enc txdata - enc.AccountNonce = hexutil.Uint64(t.AccountNonce) - enc.Price = (*hexutil.Big)(t.Price) - enc.GasLimit = hexutil.Uint64(t.GasLimit) - enc.Recipient = t.Recipient - enc.Amount = (*hexutil.Big)(t.Amount) - enc.Payload = t.Payload - enc.V = (*hexutil.Big)(t.V) - enc.R = (*hexutil.Big)(t.R) - enc.S = (*hexutil.Big)(t.S) - enc.Hash = t.Hash - return json.Marshal(&enc) -} - -func (t *txdata) UnmarshalJSON(input []byte) error { - type txdata struct { - AccountNonce *hexutil.Uint64 `json:"nonce" gencodec:"required"` - Price *hexutil.Big `json:"gasPrice" gencodec:"required"` - GasLimit *hexutil.Uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` - Amount *hexutil.Big `json:"value" gencodec:"required"` - Payload *hexutil.Bytes `json:"input" gencodec:"required"` - V *hexutil.Big `json:"v" gencodec:"required"` - R *hexutil.Big `json:"r" gencodec:"required"` - S *hexutil.Big `json:"s" gencodec:"required"` - Hash *common.Hash `json:"hash" rlp:"-"` - } - var dec txdata - if err := json.Unmarshal(input, &dec); err != nil { - return err - } - if dec.AccountNonce == nil { - return errors.New("missing required field 'nonce' for txdata") - } - t.AccountNonce = uint64(*dec.AccountNonce) - if dec.Price == nil { - return errors.New("missing required field 'gasPrice' for txdata") - } - t.Price = (*big.Int)(dec.Price) - if dec.GasLimit == nil { - return errors.New("missing required field 'gas' for txdata") - } - t.GasLimit = uint64(*dec.GasLimit) - if dec.Recipient != nil { - t.Recipient = dec.Recipient - } - if dec.Amount == nil { - return errors.New("missing required field 'value' for txdata") - } - t.Amount = (*big.Int)(dec.Amount) - if dec.Payload == nil { - return errors.New("missing required field 'input' for txdata") - } - t.Payload = *dec.Payload - if dec.V == nil { - return errors.New("missing required field 'v' for txdata") - } - t.V = (*big.Int)(dec.V) - if dec.R == nil { - return errors.New("missing required field 'r' for txdata") - } - t.R = (*big.Int)(dec.R) - if dec.S == nil { - return errors.New("missing required field 's' for txdata") - } - t.S = (*big.Int)(dec.S) - if dec.Hash != nil { - t.Hash = dec.Hash - } - return nil -} diff --git a/core/types/hashes.go b/core/types/hashes.go new file mode 100644 index 0000000000..35fc6dc9f9 --- /dev/null +++ b/core/types/hashes.go @@ -0,0 +1,39 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" +) + +var ( + // EmptyRootHash is the known root hash of an empty trie. + EmptyRootHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // EmptyUncleHash is the known hash of the empty uncle set. + EmptyUncleHash = rlpHash([]*Header(nil)) // 1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347 + + // EmptyCodeHash is the known hash of the empty EVM bytecode. + EmptyCodeHash = crypto.Keccak256Hash(nil) // c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470 + + // EmptyTxsHash is the known hash of the empty transaction set. + EmptyTxsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") + + // EmptyReceiptsHash is the known hash of the empty receipt set. + EmptyReceiptsHash = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") +) diff --git a/core/types/hashing.go b/core/types/hashing.go new file mode 100644 index 0000000000..209504435c --- /dev/null +++ b/core/types/hashing.go @@ -0,0 +1,52 @@ +package types + +import ( + "bytes" + "fmt" + "math" + "sync" + + "golang.org/x/crypto/sha3" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/rlp" +) + +// hasherPool holds LegacyKeccak256 hashers for rlpHash. +var hasherPool = sync.Pool{ + New: func() interface{} { return sha3.NewLegacyKeccak256() }, +} + +// encodeBufferPool holds temporary encoder buffers for DeriveSha and TX encoding. +var encodeBufferPool = sync.Pool{ + New: func() interface{} { return new(bytes.Buffer) }, +} + +// getPooledBuffer retrieves a buffer from the pool and creates a byte slice of the +// requested size from it. +// +// The caller should return the *bytes.Buffer object back into encodeBufferPool after use! +// The returned byte slice must not be used after returning the buffer. +func getPooledBuffer(size uint64) ([]byte, *bytes.Buffer, error) { + if size > math.MaxInt { + return nil, nil, fmt.Errorf("can't get buffer of size %d", size) + } + buf := encodeBufferPool.Get().(*bytes.Buffer) + buf.Reset() + buf.Grow(int(size)) + b := buf.Bytes()[:int(size)] + return b, buf, nil +} + +// prefixedRlpHash writes the prefix into the hasher before rlp-encoding x. +// It's used for typed transactions. +func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) { + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + sha.Write([]byte{prefix}) + rlp.Encode(sha, x) + sha.Read(h[:]) + return h +} diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go new file mode 100644 index 0000000000..d2f2781a6b --- /dev/null +++ b/core/types/hashing_test.go @@ -0,0 +1,79 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types_test + +import ( + "math/big" + "testing" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/trie" +) + +func BenchmarkDeriveSha200(b *testing.B) { + txs, err := genTxs(200) + if err != nil { + b.Fatal(err) + } + var exp common.Hash + var got common.Hash + b.Run("std_trie", func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + exp = types.DeriveSha(txs, tr) + } + }) + + b.Run("stack_trie", func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + got = types.DeriveSha(txs, trie.NewStackTrie(nil)) + } + }) + if got != exp { + b.Errorf("got %x exp %x", got, exp) + } +} + +func genTxs(num uint64) (types.Transactions, error) { + key, err := crypto.HexToECDSA("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + if err != nil { + return nil, err + } + var addr = crypto.PubkeyToAddress(key.PublicKey) + newTx := func(i uint64) (*types.Transaction, error) { + signer := types.NewEIP155Signer(big.NewInt(18)) + utx := types.NewTransaction(i, addr, new(big.Int), 0, new(big.Int).SetUint64(10000000), nil) + tx, err := types.SignTx(utx, signer, key) + return tx, err + } + var txs types.Transactions + for i := uint64(0); i < num; i++ { + tx, err := newTx(i) + if err != nil { + return nil, err + } + txs = append(txs, tx) + } + return txs, nil +} diff --git a/core/types/lending_transaction.go b/core/types/lending_transaction.go index e33826829c..7152461117 100644 --- a/core/types/lending_transaction.go +++ b/core/types/lending_transaction.go @@ -17,6 +17,7 @@ package types import ( + "bytes" "container/heap" "errors" "io" @@ -319,10 +320,12 @@ func (s LendingTransactions) Len() int { return len(s) } // Swap swaps the i'th and the j'th element in s. func (s LendingTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// GetRlp implements Rlpable and returns the i'th element of s in rlp. -func (s LendingTransactions) GetRlp(i int) []byte { - enc, _ := rlp.EncodeToBytes(s[i]) - return enc +// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors +// because we assume that *Transaction will only ever contain valid txs that were either +// constructed by decoding or via public API in this package. +func (s LendingTransactions) EncodeIndex(i int, w *bytes.Buffer) { + tx := s[i] + rlp.Encode(w, tx.data) } // LendingTxDifference returns a new set t which is the difference between a to b. @@ -363,7 +366,7 @@ func (s *LendingTxByNonce) Pop() interface{} { return x } -//LendingTransactionByNonce sort transaction by nonce +// LendingTransactionByNonce sort transaction by nonce type LendingTransactionByNonce struct { txs map[common.Address]LendingTransactions heads LendingTxByNonce diff --git a/core/types/log.go b/core/types/log.go index af8e515eac..93567b1e61 100644 --- a/core/types/log.go +++ b/core/types/log.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/rlp" ) -//go:generate gencodec -type Log -field-override logMarshaling -out gen_log_json.go +//go:generate go run github.com/fjl/gencodec -type Log -field-override logMarshaling -out gen_log_json.go // Log represents a contract log event. These events are generated by the LOG opcode and // stored/indexed by the node. @@ -63,6 +63,9 @@ type logMarshaling struct { Index hexutil.Uint } +//go:generate go run ../../rlp/rlpgen -type rlpLog -out gen_log_rlp.go + +// rlpLog is used to RLP-encode both the consensus and storage formats. type rlpLog struct { Address common.Address Topics []common.Hash diff --git a/core/types/order_transaction.go b/core/types/order_transaction.go index d51884e3f5..e7150b991e 100644 --- a/core/types/order_transaction.go +++ b/core/types/order_transaction.go @@ -17,6 +17,7 @@ package types import ( + "bytes" "container/heap" "errors" "io" @@ -250,10 +251,12 @@ func (s OrderTransactions) Len() int { return len(s) } // Swap swaps the i'th and the j'th element in s. func (s OrderTransactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// GetRlp implements Rlpable and returns the i'th element of s in rlp. -func (s OrderTransactions) GetRlp(i int) []byte { - enc, _ := rlp.EncodeToBytes(s[i]) - return enc +// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors +// because we assume that *Transaction will only ever contain valid txs that were either +// constructed by decoding or via public API in this package. +func (s OrderTransactions) EncodeIndex(i int, w *bytes.Buffer) { + tx := s[i] + rlp.Encode(w, tx.data) } // OrderTxDifference returns a new set t which is the difference between a to b. diff --git a/core/types/receipt.go b/core/types/receipt.go index 3c55c12247..9409bb26d3 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -18,35 +18,42 @@ package types import ( "bytes" + "errors" "fmt" "io" + "math/big" "unsafe" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) -//go:generate gencodec -type Receipt -field-override receiptMarshaling -out gen_receipt_json.go +//go:generate go run github.com/fjl/gencodec -type Receipt -field-override receiptMarshaling -out gen_receipt_json.go var ( receiptStatusFailedRLP = []byte{} receiptStatusSuccessfulRLP = []byte{0x01} ) +var errShortTypedReceipt = errors.New("typed receipt too short") + const ( // ReceiptStatusFailed is the status code of a transaction if execution failed. - ReceiptStatusFailed = uint(0) + ReceiptStatusFailed = uint64(0) // ReceiptStatusSuccessful is the status code of a transaction if execution succeeded. - ReceiptStatusSuccessful = uint(1) + ReceiptStatusSuccessful = uint64(1) ) // Receipt represents the results of a transaction. type Receipt struct { // Consensus fields + Type uint8 `json:"type,omitempty"` PostState []byte `json:"root"` - Status uint `json:"status"` + Status uint64 `json:"status"` CumulativeGasUsed uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom Bloom `json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` @@ -55,13 +62,22 @@ type Receipt struct { TxHash common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress common.Address `json:"contractAddress"` GasUsed uint64 `json:"gasUsed" gencodec:"required"` + + // Inclusion information: These fields provide information about the inclusion of the + // transaction corresponding to this receipt. + BlockHash common.Hash `json:"blockHash,omitempty"` + BlockNumber *big.Int `json:"blockNumber,omitempty"` + TransactionIndex uint `json:"transactionIndex"` } type receiptMarshaling struct { + Type hexutil.Uint64 PostState hexutil.Bytes - Status hexutil.Uint + Status hexutil.Uint64 CumulativeGasUsed hexutil.Uint64 GasUsed hexutil.Uint64 + BlockNumber *hexutil.Big + TransactionIndex hexutil.Uint } // receiptRLP is the consensus encoding of a receipt. @@ -72,7 +88,14 @@ type receiptRLP struct { Logs []*Log } -type receiptStorageRLP struct { +// StoredReceiptRLP is the storage encoding of a receipt. +type StoredReceiptRLP struct { + PostStateOrStatus []byte + CumulativeGasUsed uint64 + Logs []*Log +} + +type legacyStoredReceiptRLP struct { PostStateOrStatus []byte CumulativeGasUsed uint64 Bloom Bloom @@ -96,21 +119,100 @@ func NewReceipt(root []byte, failed bool, cumulativeGasUsed uint64) *Receipt { // EncodeRLP implements rlp.Encoder, and flattens the consensus fields of a receipt // into an RLP stream. If no post state is present, byzantium fork is assumed. func (r *Receipt) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs}) + data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs} + if r.Type == LegacyTxType { + return rlp.Encode(w, data) + } + buf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(buf) + buf.Reset() + if err := r.encodeTyped(data, buf); err != nil { + return err + } + return rlp.Encode(w, buf.Bytes()) +} + +// encodeTyped writes the canonical encoding of a typed receipt to w. +func (r *Receipt) encodeTyped(data *receiptRLP, w *bytes.Buffer) error { + w.WriteByte(r.Type) + return rlp.Encode(w, data) +} + +// MarshalBinary returns the consensus encoding of the receipt. +func (r *Receipt) MarshalBinary() ([]byte, error) { + if r.Type == LegacyTxType { + return rlp.EncodeToBytes(r) + } + data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs} + var buf bytes.Buffer + err := r.encodeTyped(data, &buf) + return buf.Bytes(), err } // DecodeRLP implements rlp.Decoder, and loads the consensus fields of a receipt // from an RLP stream. func (r *Receipt) DecodeRLP(s *rlp.Stream) error { - var dec receiptRLP - if err := s.Decode(&dec); err != nil { + kind, _, err := s.Kind() + switch { + case err != nil: return err + case kind == rlp.List: + // It's a legacy receipt. + var dec receiptRLP + if err := s.Decode(&dec); err != nil { + return err + } + r.Type = LegacyTxType + return r.setFromRLP(dec) + default: + // It's an EIP-2718 typed tx receipt. + b, err := s.Bytes() + if err != nil { + return err + } + return r.decodeTyped(b) } - if err := r.setStatus(dec.PostStateOrStatus); err != nil { - return err +} + +// UnmarshalBinary decodes the consensus encoding of receipts. +// It supports legacy RLP receipts and EIP-2718 typed receipts. +func (r *Receipt) UnmarshalBinary(b []byte) error { + if len(b) > 0 && b[0] > 0x7f { + // It's a legacy receipt decode the RLP + var data receiptRLP + err := rlp.DecodeBytes(b, &data) + if err != nil { + return err + } + r.Type = LegacyTxType + return r.setFromRLP(data) } - r.CumulativeGasUsed, r.Bloom, r.Logs = dec.CumulativeGasUsed, dec.Bloom, dec.Logs - return nil + // It's an EIP2718 typed transaction envelope. + return r.decodeTyped(b) +} + +// decodeTyped decodes a typed receipt from the canonical format. +func (r *Receipt) decodeTyped(b []byte) error { + if len(b) <= 1 { + return errShortTypedReceipt + } + switch b[0] { + case PaymasterTxType: + var data receiptRLP + err := rlp.DecodeBytes(b[1:], &data) + if err != nil { + return err + } + r.Type = b[0] + return r.setFromRLP(data) + default: + return ErrTxTypeNotSupported + } +} + +func (r *Receipt) setFromRLP(data receiptRLP) error { + r.CumulativeGasUsed, r.Bloom, r.Logs = data.CumulativeGasUsed, data.Bloom, data.Logs + return r.setStatus(data.PostStateOrStatus) } func (r *Receipt) setStatus(postStateOrStatus []byte) error { @@ -141,7 +243,6 @@ func (r *Receipt) statusEncoding() []byte { // to approximate and limit the memory consumption of various caches. func (r *Receipt) Size() common.StorageSize { size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState)) - size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{})) for _, log := range r.Logs { size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data)) @@ -152,9 +253,9 @@ func (r *Receipt) Size() common.StorageSize { // String implements the Stringer interface. func (r *Receipt) String() string { if len(r.PostState) == 0 { - return fmt.Sprintf("receipt{status=%d cgas=%v bloom=%x logs=%v}", r.Status, r.CumulativeGasUsed, r.Bloom, r.Logs) + return fmt.Sprintf("receipt{type=%d status=%d cgas=%v bloom=%x logs=%v}", r.Type, r.Status, r.CumulativeGasUsed, r.Bloom, r.Logs) } - return fmt.Sprintf("receipt{med=%x cgas=%v bloom=%x logs=%v}", r.PostState, r.CumulativeGasUsed, r.Bloom, r.Logs) + return fmt.Sprintf("receipt{type=%d med=%x cgas=%v bloom=%x logs=%v}", r.Type, r.PostState, r.CumulativeGasUsed, r.Bloom, r.Logs) } // ReceiptForStorage is a wrapper around a Receipt that flattens and parses the @@ -163,50 +264,151 @@ type ReceiptForStorage Receipt // EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt // into an RLP stream. -func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error { - enc := &receiptStorageRLP{ - PostStateOrStatus: (*Receipt)(r).statusEncoding(), - CumulativeGasUsed: r.CumulativeGasUsed, - Bloom: r.Bloom, - TxHash: r.TxHash, - ContractAddress: r.ContractAddress, - Logs: make([]*LogForStorage, len(r.Logs)), - GasUsed: r.GasUsed, - } - for i, log := range r.Logs { - enc.Logs[i] = (*LogForStorage)(log) +func (r *ReceiptForStorage) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + outerList := w.List() + w.WriteBytes((*Receipt)(r).statusEncoding()) + w.WriteUint64(r.CumulativeGasUsed) + logList := w.List() + for _, log := range r.Logs { + if err := rlp.Encode(w, log); err != nil { + return err + } } - return rlp.Encode(w, enc) + w.ListEnd(logList) + w.ListEnd(outerList) + return w.Flush() } // DecodeRLP implements rlp.Decoder, and loads both consensus and implementation // fields of a receipt from an RLP stream. func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { - var dec receiptStorageRLP - if err := s.Decode(&dec); err != nil { + // Retrieve the entire receipt blob as we need to try multiple decoders + blob, err := s.Raw() + if err != nil { + return err + } + + // Try decoding from the newest format for future proofness, then the older one + // for old nodes that just upgraded. V4 was an intermediate unreleased format so + // we do need to decode it, but it's not common (try last). + if err := decodeStoredReceiptRLP(r, blob); err == nil { + return nil + } + + return decodeLegacyStoredReceiptRLP(r, blob) +} + +func decodeStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { + var stored StoredReceiptRLP + if err := rlp.DecodeBytes(blob, &stored); err != nil { + return err + } + if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { return err } - if err := (*Receipt)(r).setStatus(dec.PostStateOrStatus); err != nil { + r.CumulativeGasUsed = stored.CumulativeGasUsed + r.Logs = stored.Logs + r.Bloom = CreateBloom(Receipts{(*Receipt)(r)}) + + return nil +} + +func decodeLegacyStoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { + var stored legacyStoredReceiptRLP + if err := rlp.DecodeBytes(blob, &stored); err != nil { return err } - // Assign the consensus fields - r.CumulativeGasUsed, r.Bloom = dec.CumulativeGasUsed, dec.Bloom - r.Logs = make([]*Log, len(dec.Logs)) - for i, log := range dec.Logs { + if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { + return err + } + r.CumulativeGasUsed = stored.CumulativeGasUsed + r.TxHash = stored.TxHash + r.ContractAddress = stored.ContractAddress + r.GasUsed = stored.GasUsed + r.Logs = make([]*Log, len(stored.Logs)) + for i, log := range stored.Logs { r.Logs[i] = (*Log)(log) } - // Assign the implementation fields - r.TxHash, r.ContractAddress, r.GasUsed = dec.TxHash, dec.ContractAddress, dec.GasUsed + r.Bloom = CreateBloom(Receipts{(*Receipt)(r)}) + return nil } -// Receipts is a wrapper around a Receipt array to implement DerivableList. +// Receipts implements DerivableList for receipts. type Receipts []*Receipt // Len returns the number of receipts in this list. -func (r Receipts) Len() int { return len(r) } +func (rs Receipts) Len() int { return len(rs) } + +// EncodeIndex encodes the i'th receipt to w. +func (rs Receipts) EncodeIndex(i int, w *bytes.Buffer) { + r := rs[i] + data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs} + if r.Type == LegacyTxType { + rlp.Encode(w, data) + return + } + w.WriteByte(r.Type) + switch r.Type { + case PaymasterTxType: + rlp.Encode(w, data) + default: + // For unsupported types, write nothing. Since this is for + // DeriveSha, the error will be caught matching the derived hash + // to the block. + } +} + +// DeriveFields fills the receipts with their computed fields based on consensus +// data and contextual infos like containing block and transactions. +func (rs Receipts) DeriveFields(config *params.ChainConfig, hash common.Hash, number uint64, txs []*Transaction) error { + signer := MakeSigner(config, new(big.Int).SetUint64(number)) + + logIndex := uint(0) + if len(txs) != len(rs) { + return errors.New("transaction and receipt count mismatch") + } + for i := 0; i < len(rs); i++ { + // The transaction type and hash can be retrieved from the transaction itself + rs[i].Type = txs[i].Type() + rs[i].TxHash = txs[i].Hash() + + // block location fields + rs[i].BlockHash = hash + rs[i].BlockNumber = new(big.Int).SetUint64(number) + rs[i].TransactionIndex = uint(i) + + // The contract address can be derived from the transaction itself + if txs[i].To() == nil { + // Deriving the signer is expensive, only do if it's actually needed + from, _ := Sender(signer, txs[i]) + rs[i].ContractAddress = crypto.CreateAddress(from, txs[i].Nonce()) + } else { + rs[i].ContractAddress = common.Address{} + } + + // The used gas can be calculated based on previous r + if i == 0 { + rs[i].GasUsed = rs[i].CumulativeGasUsed + } else { + rs[i].GasUsed = rs[i].CumulativeGasUsed - rs[i-1].CumulativeGasUsed + } + + // The derived log fields can simply be set from the block and transaction + for j := 0; j < len(rs[i].Logs); j++ { + rs[i].Logs[j].BlockNumber = number + rs[i].Logs[j].BlockHash = hash + rs[i].Logs[j].TxHash = rs[i].TxHash + rs[i].Logs[j].TxIndex = uint(i) + rs[i].Logs[j].Index = logIndex + logIndex++ + } + } + return nil +} -// GetRlp returns the RLP encoding of one receipt from the list. +// GetRlp returns the RLP encoding of one receipt from the list.. func (r Receipts) GetRlp(i int) []byte { bytes, err := rlp.EncodeToBytes(r[i]) if err != nil { diff --git a/core/types/state_account.go b/core/types/state_account.go new file mode 100644 index 0000000000..01c552a04c --- /dev/null +++ b/core/types/state_account.go @@ -0,0 +1,103 @@ +package types + +import ( + "bytes" + "math/big" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/rlp" +) + +// StateAccount is the Ethereum consensus representation of accounts. +// These objects are stored in the main account trie. +type StateAccount struct { + Nonce uint64 + Balance *big.Int + Root common.Hash // merkle root of the storage trie + CodeHash []byte +} + +// NewEmptyStateAccount constructs an empty state account. +func NewEmptyStateAccount() *StateAccount { + return &StateAccount{ + Balance: new(big.Int), + Root: EmptyRootHash, + CodeHash: EmptyCodeHash.Bytes(), + } +} + +// Copy returns a deep-copied state account object. +func (acct *StateAccount) Copy() *StateAccount { + var balance *big.Int + if acct.Balance != nil { + balance = new(big.Int).Set(acct.Balance) + } + return &StateAccount{ + Nonce: acct.Nonce, + Balance: balance, + Root: acct.Root, + CodeHash: common.CopyBytes(acct.CodeHash), + } +} + +// SlimAccount is a modified version of an Account, where the root is replaced +// with a byte slice. This format can be used to represent full-consensus format +// or slim format which replaces the empty root and code hash as nil byte slice. +type SlimAccount struct { + Nonce uint64 + Balance *big.Int + Root []byte // Nil if root equals to types.EmptyRootHash + CodeHash []byte // Nil if hash equals to types.EmptyCodeHash +} + +// SlimAccountRLP encodes the state account in 'slim RLP' format. +func SlimAccountRLP(account StateAccount) []byte { + slim := SlimAccount{ + Nonce: account.Nonce, + Balance: account.Balance, + } + if account.Root != EmptyRootHash { + slim.Root = account.Root[:] + } + if !bytes.Equal(account.CodeHash, EmptyCodeHash[:]) { + slim.CodeHash = account.CodeHash + } + data, err := rlp.EncodeToBytes(slim) + if err != nil { + panic(err) + } + return data +} + +// FullAccount decodes the data on the 'slim RLP' format and return +// the consensus format account. +func FullAccount(data []byte) (*StateAccount, error) { + var slim SlimAccount + if err := rlp.DecodeBytes(data, &slim); err != nil { + return nil, err + } + var account StateAccount + account.Nonce, account.Balance = slim.Nonce, slim.Balance + + // Interpret the storage root and code hash in slim format. + if len(slim.Root) == 0 { + account.Root = EmptyRootHash + } else { + account.Root = common.BytesToHash(slim.Root) + } + if len(slim.CodeHash) == 0 { + account.CodeHash = EmptyCodeHash[:] + } else { + account.CodeHash = slim.CodeHash + } + return &account, nil +} + +// FullAccountRLP converts data on the 'slim RLP' format into the full RLP-format. +func FullAccountRLP(data []byte) ([]byte, error) { + account, err := FullAccount(data) + if err != nil { + return nil, err + } + return rlp.EncodeToBytes(account) +} diff --git a/core/types/transaction.go b/core/types/transaction.go index cf546c4420..dfddf5ceb0 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -17,6 +17,7 @@ package types import ( + "bytes" "container/heap" "errors" "fmt" @@ -25,8 +26,6 @@ import ( "sync/atomic" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/common/hexutil" - "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" ) @@ -34,6 +33,10 @@ import ( var ( ErrInvalidSig = errors.New("invalid transaction v, r, s values") + ErrUnexpectedProtection = errors.New("transaction type does not supported EIP-155 protected signatures") + ErrInvalidTxType = errors.New("transaction type not valid in this context") + ErrTxTypeNotSupported = errors.New("transaction type not supported") + errShortTypedTx = errors.New("typed transaction too short") errNoSigner = errors.New("missing signing methods") skipNonceDestinationAddress = map[string]bool{ common.TomoXAddr: true, @@ -43,6 +46,12 @@ var ( } ) +// Transaction types. +const ( + LegacyTxType = 0x00 + PaymasterTxType = 0x04 +) + // deriveSigner makes a *best* guess about which signer to use. func deriveSigner(V *big.Int) Signer { if V.Sign() != 0 && isProtectedV(V) { @@ -53,82 +62,63 @@ func deriveSigner(V *big.Int) Signer { } type Transaction struct { - data txdata + inner TxData // caches hash atomic.Value size atomic.Value from atomic.Value } -type txdata struct { - AccountNonce uint64 `json:"nonce" gencodec:"required"` - Price *big.Int `json:"gasPrice" gencodec:"required"` - GasLimit uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` // nil means contract creation - Amount *big.Int `json:"value" gencodec:"required"` - Payload []byte `json:"input" gencodec:"required"` - - // Signature values - V *big.Int `json:"v" gencodec:"required"` - R *big.Int `json:"r" gencodec:"required"` - S *big.Int `json:"s" gencodec:"required"` - - // This is only used when marshaling to JSON. - Hash *common.Hash `json:"hash" rlp:"-"` +// NewTx creates a new transaction. +func NewTx(inner TxData) *Transaction { + tx := new(Transaction) + tx.setDecoded(inner.copy(), 0) + return tx } -type txdataMarshaling struct { - AccountNonce hexutil.Uint64 - Price *hexutil.Big - GasLimit hexutil.Uint64 - Amount *hexutil.Big - Payload hexutil.Bytes - V *hexutil.Big - R *hexutil.Big - S *hexutil.Big -} +// TxData is the underlying inner of a transaction. +// +// This is implemented by LegacyTx and PaymasterTx. +type TxData interface { + txType() byte // returns the type ID + copy() TxData // creates a deep copy and initializes all fields -func NewTransaction(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - return newTransaction(nonce, &to, amount, gasLimit, gasPrice, data) -} + chainID() *big.Int + data() []byte + gas() uint64 + gasPrice() *big.Int + value() *big.Int + nonce() uint64 + to() *common.Address + pmPayload() []byte + + rawSignatureValues() (v, r, s *big.Int) + setSignatureValues(chainID, v, r, s *big.Int) -func NewContractCreation(nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - return newTransaction(nonce, nil, amount, gasLimit, gasPrice, data) + encode(*bytes.Buffer) error + decode([]byte) error } -func newTransaction(nonce uint64, to *common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - if len(data) > 0 { - data = common.CopyBytes(data) - } - d := txdata{ - AccountNonce: nonce, - Recipient: to, - Payload: data, - Amount: new(big.Int), - GasLimit: gasLimit, - Price: new(big.Int), - V: new(big.Int), - R: new(big.Int), - S: new(big.Int), - } - if amount != nil { - d.Amount.Set(amount) - } - if gasPrice != nil { - d.Price.Set(gasPrice) +// Protected says whether the transaction is replay-protected. +func (tx *Transaction) Protected() bool { + switch tx := tx.inner.(type) { + case *LegacyTx: + return tx.V != nil && isProtectedV(tx.V) + default: + return true } - - return &Transaction{data: d} } -// ChainId returns which chain id this transaction was signed for (if at all) -func (tx *Transaction) ChainId() *big.Int { - return deriveChainId(tx.data.V) +// Type returns the transaction type. +func (tx *Transaction) Type() uint8 { + return tx.inner.txType() } -// Protected returns whether the transaction is protected from replay protection. -func (tx *Transaction) Protected() bool { - return isProtectedV(tx.data.V) +// ChainId returns the EIP155 chain ID of the transaction. The return value will always be +// non-nil. For legacy transactions which are not replay-protected, the return value is +// zero. +func (tx *Transaction) ChainId() *big.Int { + return tx.inner.chainID() } func isProtectedV(V *big.Int) bool { @@ -142,68 +132,136 @@ func isProtectedV(V *big.Int) bool { // EncodeRLP implements rlp.Encoder func (tx *Transaction) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &tx.data) + if tx.Type() == LegacyTxType { + return rlp.Encode(w, tx.inner) + } + // It's an EIP-2718 typed TX envelope. + buf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(buf) + buf.Reset() + if err := tx.encodeTyped(buf); err != nil { + return err + } + return rlp.Encode(w, buf.Bytes()) +} + +// encodeTyped writes the canonical encoding of a typed transaction to w. +func (tx *Transaction) encodeTyped(w *bytes.Buffer) error { + w.WriteByte(tx.Type()) + return tx.inner.encode(w) +} + +// MarshalBinary returns the canonical encoding of the transaction. +// For legacy transactions, it returns the RLP encoding. For EIP-2718 typed +// transactions, it returns the type and payload. +func (tx *Transaction) MarshalBinary() ([]byte, error) { + if tx.Type() == LegacyTxType { + return rlp.EncodeToBytes(tx.inner) + } + var buf bytes.Buffer + err := tx.encodeTyped(&buf) + return buf.Bytes(), err } // DecodeRLP implements rlp.Decoder func (tx *Transaction) DecodeRLP(s *rlp.Stream) error { - _, size, _ := s.Kind() - err := s.Decode(&tx.data) - if err == nil { - tx.size.Store(common.StorageSize(rlp.ListSize(size))) + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == rlp.List: + // It's a legacy transaction. + var inner LegacyTx + err := s.Decode(&inner) + if err == nil { + tx.setDecoded(&inner, rlp.ListSize(size)) + } + return err + case kind == rlp.Byte: + return errShortTypedTx + default: + // It's an EIP-2718 typed TX envelope. + // First read the tx payload bytes into a temporary buffer. + b, buf, err := getPooledBuffer(size) + if err != nil { + return err + } + defer encodeBufferPool.Put(buf) + if err := s.ReadBytes(b); err != nil { + return err + } + // Now decode the inner transaction. + inner, err := tx.decodeTyped(b) + if err == nil { + tx.setDecoded(inner, size) + } + return err } - - return err } -// MarshalJSON encodes the web3 RPC transaction format. -func (tx *Transaction) MarshalJSON() ([]byte, error) { - hash := tx.Hash() - data := tx.data - data.Hash = &hash - return data.MarshalJSON() +// UnmarshalBinary decodes the canonical encoding of transactions. +// It supports legacy RLP transactions and EIP-2718 typed transactions. +func (tx *Transaction) UnmarshalBinary(b []byte) error { + if len(b) > 0 && b[0] > 0x7f { + // It's a legacy transaction. + var data LegacyTx + err := rlp.DecodeBytes(b, &data) + if err != nil { + return err + } + tx.setDecoded(&data, uint64(len(b))) + return nil + } + // It's an EIP-2718 typed transaction envelope. + inner, err := tx.decodeTyped(b) + if err != nil { + return err + } + tx.setDecoded(inner, uint64(len(b))) + return nil } -// UnmarshalJSON decodes the web3 RPC transaction format. -func (tx *Transaction) UnmarshalJSON(input []byte) error { - var dec txdata - if err := dec.UnmarshalJSON(input); err != nil { - return err +// decodeTyped decodes a typed transaction from the canonical format. +func (tx *Transaction) decodeTyped(b []byte) (TxData, error) { + if len(b) <= 1 { + return nil, errShortTypedTx } - var V byte - if isProtectedV(dec.V) { - chainID := deriveChainId(dec.V).Uint64() - V = byte(dec.V.Uint64() - 35 - 2*chainID) - } else { - V = byte(dec.V.Uint64() - 27) + var inner TxData + switch b[0] { + case PaymasterTxType: + inner = new(PaymasterTx) + default: + return nil, ErrTxTypeNotSupported } - if !crypto.ValidateSignatureValues(V, dec.R, dec.S, false) { - return ErrInvalidSig + err := inner.decode(b[1:]) + return inner, err +} + +// setDecoded sets the inner transaction and size after decoding. +func (tx *Transaction) setDecoded(inner TxData, size uint64) { + tx.inner = inner + if size > 0 { + tx.size.Store(size) } - *tx = Transaction{data: dec} - return nil } -func (tx *Transaction) Data() []byte { return common.CopyBytes(tx.data.Payload) } -func (tx *Transaction) Gas() uint64 { return tx.data.GasLimit } -func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.data.Price) } -func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.data.Amount) } -func (tx *Transaction) Nonce() uint64 { return tx.data.AccountNonce } -func (tx *Transaction) CheckNonce() bool { return true } +func (tx *Transaction) Data() []byte { return common.CopyBytes(tx.inner.data()) } +func (tx *Transaction) Gas() uint64 { return tx.inner.gas() } +func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.inner.gasPrice()) } +func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.inner.value()) } +func (tx *Transaction) Nonce() uint64 { return tx.inner.nonce() } +func (tx *Transaction) PmPayload() []byte { return tx.inner.pmPayload() } // To returns the recipient address of the transaction. -// It returns nil if the transaction is a contract creation. +// For contract-creation transactions, To returns nil. func (tx *Transaction) To() *common.Address { - if tx.data.Recipient == nil { - return nil - } - to := *tx.data.Recipient - return &to + return copyAddressPtr(tx.inner.to()) } func (tx *Transaction) From() *common.Address { - if tx.data.V != nil { - signer := deriveSigner(tx.data.V) + v, _, _ := tx.RawSignatureValues() + if v != nil { + signer := deriveSigner(v) if f, err := Sender(signer, tx); err != nil { return nil } else { @@ -230,44 +288,26 @@ func (tx *Transaction) CacheHash() { tx.hash.Store(v) } -// Size returns the true RLP encoded storage size of the transaction, either by -// encoding and returning it, or returning a previsouly cached value. -func (tx *Transaction) Size() common.StorageSize { +// Size returns the true encoded storage size of the transaction, either by encoding +// and returning it, or returning a previously cached value. +func (tx *Transaction) Size() uint64 { if size := tx.size.Load(); size != nil { - return size.(common.StorageSize) + return size.(uint64) } + + // Cache miss, encode and cache. + // Note we rely on the assumption that all tx.inner values are RLP-encoded! c := writeCounter(0) - rlp.Encode(&c, &tx.data) - tx.size.Store(common.StorageSize(c)) - return common.StorageSize(c) -} + rlp.Encode(&c, &tx.inner) + size := uint64(c) -// AsMessage returns the transaction as a core.Message. -// -// AsMessage requires a signer to derive the sender. -// -// XXX Rename message to something less arbitrary? -func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) (Message, error) { - msg := Message{ - nonce: tx.data.AccountNonce, - gasLimit: tx.data.GasLimit, - gasPrice: new(big.Int).Set(tx.data.Price), - to: tx.data.Recipient, - amount: tx.data.Amount, - data: tx.data.Payload, - checkNonce: true, - balanceTokenFee: balanceFee, - } - var err error - msg.from, err = Sender(s, tx) - if balanceFee != nil { - if number.Cmp(common.TIPTRC21Fee) > 0 { - msg.gasPrice = common.TRC21GasPrice - } else { - msg.gasPrice = common.TRC21GasPriceBefore - } + // For typed transactions, the encoding also includes the leading type byte. + if tx.Type() != LegacyTxType { + size += 1 } - return msg, err + + tx.size.Store(size) + return size } // WithSignature returns a new transaction with the given signature. @@ -277,27 +317,29 @@ func (tx *Transaction) WithSignature(signer Signer, sig []byte) (*Transaction, e if err != nil { return nil, err } - cpy := &Transaction{data: tx.data} - cpy.data.R, cpy.data.S, cpy.data.V = r, s, v - return cpy, nil + cpy := tx.inner.copy() + cpy.setSignatureValues(signer.ChainID(), v, r, s) + return &Transaction{inner: cpy}, nil } // Cost returns amount + gasprice * gaslimit. func (tx *Transaction) Cost() *big.Int { - total := new(big.Int).Mul(tx.data.Price, new(big.Int).SetUint64(tx.data.GasLimit)) - total.Add(total, tx.data.Amount) + total := new(big.Int).Mul(tx.GasPrice(), new(big.Int).SetUint64(tx.Gas())) + total.Add(total, tx.Value()) return total } -// Cost returns amount + gasprice * gaslimit. +// TRC21Cost returns amount + gasprice * gaslimit. func (tx *Transaction) TRC21Cost() *big.Int { - total := new(big.Int).Mul(common.TRC21GasPrice, new(big.Int).SetUint64(tx.data.GasLimit)) - total.Add(total, tx.data.Amount) + total := new(big.Int).Mul(common.TRC21GasPrice, new(big.Int).SetUint64(tx.Gas())) + total.Add(total, tx.Value()) return total } -func (tx *Transaction) RawSignatureValues() (*big.Int, *big.Int, *big.Int) { - return tx.data.V, tx.data.R, tx.data.S +// RawSignatureValues returns the V, R, S signature values of the transaction. +// The return values should not be modified by the caller. +func (tx *Transaction) RawSignatureValues() (v, r, s *big.Int) { + return tx.inner.rawSignatureValues() } func (tx *Transaction) IsSpecialTransaction() bool { @@ -467,10 +509,11 @@ func (tx *Transaction) IsTomoZApplyTransaction() bool { func (tx *Transaction) String() string { var from, to string - if tx.data.V != nil { + v, r, s := tx.RawSignatureValues() + if v != nil { // make a best guess about the signer and use that to derive // the sender. - signer := deriveSigner(tx.data.V) + signer := deriveSigner(v) if f, err := Sender(signer, tx); err != nil { // derive but don't cache from = "[invalid sender: invalid sig]" } else { @@ -480,39 +523,43 @@ func (tx *Transaction) String() string { from = "[invalid sender: nil V field]" } - if tx.data.Recipient == nil { + if tx.To() == nil { to = "[contract creation]" } else { - to = fmt.Sprintf("%x", tx.data.Recipient[:]) + to = fmt.Sprintf("%x", tx.To()[:]) } - enc, _ := rlp.EncodeToBytes(&tx.data) + enc, _ := rlp.EncodeToBytes(&tx.inner) return fmt.Sprintf(` TX(%x) - Contract: %v - From: %s - To: %s - Nonce: %v - GasPrice: %#x - GasLimit %#x - Value: %#x - Data: 0x%x - V: %#x - R: %#x - S: %#x - Hex: %x + Type: %v + Contract: %v + From: %s + To: %s + Nonce: %v + GasPrice: %#x + GasLimit %#x + Value: %#x + Data: 0x%x + PMPayload: 0x%x + V: %#x + R: %#x + S: %#x + Hex: %x `, - tx.Hash(), - tx.data.Recipient == nil, + tx.Hash().Hex(), + tx.Type(), + tx.To() == nil, from, to, - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Amount, - tx.data.Payload, - tx.data.V, - tx.data.R, - tx.data.S, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.Value(), + tx.Data(), + tx.PmPayload(), + v, + r, + s, enc, ) } @@ -523,15 +570,17 @@ type Transactions []*Transaction // Len returns the length of s. func (s Transactions) Len() int { return len(s) } +// EncodeIndex encodes the i'th transaction to w. Note that this does not check for errors +// because we assume that *Transaction will only ever contain valid txs that were either +// constructed by decoding or via public API in this package. +func (s Transactions) EncodeIndex(i int, w *bytes.Buffer) { + tx := s[i] + rlp.Encode(w, tx.inner) +} + // Swap swaps the i'th and the j'th element in s. func (s Transactions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// GetRlp implements Rlpable and returns the i'th element of s in rlp. -func (s Transactions) GetRlp(i int) []byte { - enc, _ := rlp.EncodeToBytes(s[i]) - return enc -} - // TxDifference returns a new set t which is the difference between a to b. func TxDifference(a, b Transactions) (keep Transactions) { keep = make(Transactions, 0, len(a)) @@ -556,7 +605,7 @@ func TxDifference(a, b Transactions) (keep Transactions) { type TxByNonce Transactions func (s TxByNonce) Len() int { return len(s) } -func (s TxByNonce) Less(i, j int) bool { return s[i].data.AccountNonce < s[j].data.AccountNonce } +func (s TxByNonce) Less(i, j int) bool { return s[i].Nonce() < s[j].Nonce() } func (s TxByNonce) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // TxByPrice implements both the sort and the heap interface, making it useful @@ -568,14 +617,14 @@ type TxByPrice struct { func (s TxByPrice) Len() int { return len(s.txs) } func (s TxByPrice) Less(i, j int) bool { - i_price := s.txs[i].data.Price + i_price := s.txs[i].GasPrice() if s.txs[i].To() != nil { if _, ok := s.payersSwap[*s.txs[i].To()]; ok { i_price = common.TRC21GasPrice } } - j_price := s.txs[j].data.Price + j_price := s.txs[j].GasPrice() if s.txs[j].To() != nil { if _, ok := s.payersSwap[*s.txs[j].To()]; ok { j_price = common.TRC21GasPrice @@ -681,44 +730,11 @@ func (t *TransactionsByPriceAndNonce) Pop() { heap.Pop(&t.heads) } -// Message is a fully derived transaction and implements core.Message -// -// NOTE: In a future PR this will be removed. -type Message struct { - to *common.Address - from common.Address - nonce uint64 - amount *big.Int - gasLimit uint64 - gasPrice *big.Int - data []byte - checkNonce bool - balanceTokenFee *big.Int -} - -func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, checkNonce bool, balanceTokenFee *big.Int) Message { - if balanceTokenFee != nil { - gasPrice = common.TRC21GasPrice - } - return Message{ - from: from, - to: to, - nonce: nonce, - amount: amount, - gasLimit: gasLimit, - gasPrice: gasPrice, - data: data, - checkNonce: checkNonce, - balanceTokenFee: balanceTokenFee, - } -} - -func (m Message) From() common.Address { return m.from } -func (m Message) BalanceTokenFee() *big.Int { return m.balanceTokenFee } -func (m Message) To() *common.Address { return m.to } -func (m Message) GasPrice() *big.Int { return m.gasPrice } -func (m Message) Value() *big.Int { return m.amount } -func (m Message) Gas() uint64 { return m.gasLimit } -func (m Message) Nonce() uint64 { return m.nonce } -func (m Message) Data() []byte { return m.data } -func (m Message) CheckNonce() bool { return m.checkNonce } +// copyAddressPtr copies an address. +func copyAddressPtr(a *common.Address) *common.Address { + if a == nil { + return nil + } + cpy := *a + return &cpy +} diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 5f353a5733..cc79cbbdff 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -42,6 +42,8 @@ type sigCache struct { func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { var signer Signer switch { + case config.IsEIP2718(blockNumber): + signer = NewEIP2718Signer(config.ChainId) case config.IsEIP155(blockNumber): signer = NewEIP155Signer(config.ChainId) case config.IsHomestead(blockNumber): @@ -52,6 +54,39 @@ func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { return signer } +// LatestSigner returns the 'most permissive' Signer available for the given chain +// configuration. Specifically, this enables support of all types of transactions +// when their respective forks are scheduled to occur at any block number (or time) +// in the chain config. +// +// Use this in transaction-handling code where the current block number is unknown. If you +// have the current block number available, use MakeSigner instead. +func LatestSigner(config *params.ChainConfig) Signer { + if config.ChainId != nil { + if config.EIP2718Block != nil { + return NewEIP2718Signer(config.ChainId) + } + if config.EIP155Block != nil { + return NewEIP155Signer(config.ChainId) + } + } + return HomesteadSigner{} +} + +// LatestSignerForChainID returns the 'most permissive' Signer available. Specifically, +// this enables support for EIP-155 replay protection and all implemented EIP-2718 +// transaction types if chainID is non-nil. +// +// Use this in transaction-handling code where the current block number and fork +// configuration are unknown. If you have a ChainConfig, use LatestSigner instead. +// If you have a ChainConfig and know the current block number, use MakeSigner instead. +func LatestSignerForChainID(chainID *big.Int) Signer { + if chainID == nil { + return HomesteadSigner{} + } + return NewEIP2718Signer(chainID) +} + // SignTx signs the transaction using the given signer and private key func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, error) { h := s.Hash(tx) @@ -62,6 +97,17 @@ func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, err return tx.WithSignature(s, sig) } +// SignNewTx creates a transaction and signs it. +func SignNewTx(prv *ecdsa.PrivateKey, s Signer, txdata TxData) (*Transaction, error) { + tx := NewTx(txdata) + h := s.Hash(tx) + sig, err := crypto.Sign(h[:], prv) + if err != nil { + return nil, err + } + return tx.WithSignature(s, sig) +} + // Sender returns the address derived from the signature (V, R, S) using secp256k1 // elliptic curve and an error if it failed deriving or upon an incorrect // signature. @@ -93,15 +139,99 @@ func Sender(signer Signer, tx *Transaction) (common.Address, error) { type Signer interface { // Sender returns the sender address of the transaction. Sender(tx *Transaction) (common.Address, error) + // SignatureValues returns the raw R, S, V values corresponding to the // given signature. - SignatureValues(tx *Transaction, sig []byte) (r, s, v *big.Int, err error) - // Hash returns the hash to be signed. + SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) + ChainID() *big.Int + + // Hash returns 'signature hash', i.e. the transaction hash that is signed by the + // private key. This hash does not uniquely identify the transaction. Hash(tx *Transaction) common.Hash + // Equal returns true if the given signer is the same as the receiver. Equal(Signer) bool } +type eip2718Signer struct{ EIP155Signer } + +// NewEIP2718Signer returns a signer that accepts +// - EIP-2718 paymaster transactions +// - EIP-155 replay protected transactions, and +// - legacy Homestead transactions. +func NewEIP2718Signer(chainId *big.Int) Signer { + return eip2718Signer{NewEIP155Signer(chainId)} +} + +func (s eip2718Signer) Sender(tx *Transaction) (common.Address, error) { + V, R, S := tx.RawSignatureValues() + switch tx.Type() { + case LegacyTxType: + return s.EIP155Signer.Sender(tx) + case PaymasterTxType: + // PM txs are defined to use 0 and 1 as their recovery + // id, add 27 to become equivalent to unprotected Homestead signatures. + V = new(big.Int).Add(V, big.NewInt(27)) + default: + return common.Address{}, ErrTxTypeNotSupported + } + if tx.ChainId().Cmp(s.chainId) != 0 { + return common.Address{}, fmt.Errorf("EIP2718Signer %w: have %d want %d", ErrInvalidChainId, tx.ChainId(), s.chainId) + } + return recoverPlain(s.Hash(tx), R, S, V, true) +} + +func (s eip2718Signer) Equal(s2 Signer) bool { + x, ok := s2.(eip2718Signer) + return ok && x.chainId.Cmp(s.chainId) == 0 +} + +func (s eip2718Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { + switch txdata := tx.inner.(type) { + case *LegacyTx: + return s.EIP155Signer.SignatureValues(tx, sig) + case *PaymasterTx: + // Check that chain ID of tx matches the signer. We also accept ID zero here, + // because it indicates that the chain ID was not specified in the tx. + if txdata.ChainID.Sign() != 0 && txdata.ChainID.Cmp(s.chainId) != 0 { + return nil, nil, nil, fmt.Errorf("EIP2718Signer %w: have %d want %d", ErrInvalidChainId, txdata.ChainID, s.chainId) + } + R, S, _ = decodeSignature(sig) + V = big.NewInt(int64(sig[64])) + default: + return nil, nil, nil, ErrTxTypeNotSupported + } + return R, S, V, nil +} + +// Hash returns the hash to be signed by the sender. +// It does not uniquely identify the transaction. +func (s eip2718Signer) Hash(tx *Transaction) common.Hash { + switch tx.Type() { + case LegacyTxType: + return s.EIP155Signer.Hash(tx) + case PaymasterTxType: + return prefixedRlpHash( + tx.Type(), + []interface{}{ + s.chainId, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + tx.PmPayload(), + }) + default: + // This _should_ not happen, but in case someone sends in a bad + // json struct via RPC, it's probably more prudent to return an + // empty hash instead of killing the node with a panic + //panic("Unsupported transaction type: %d", tx.typ) + return common.Hash{} + } +} + // EIP155Transaction implements Signer using the EIP155 rules. type EIP155Signer struct { chainId, chainIdMul *big.Int @@ -117,6 +247,10 @@ func NewEIP155Signer(chainId *big.Int) EIP155Signer { } } +func (s EIP155Signer) ChainID() *big.Int { + return s.chainId +} + func (s EIP155Signer) Equal(s2 Signer) bool { eip155, ok := s2.(EIP155Signer) return ok && eip155.chainId.Cmp(s.chainId) == 0 @@ -125,15 +259,19 @@ func (s EIP155Signer) Equal(s2 Signer) bool { var big8 = big.NewInt(8) func (s EIP155Signer) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } if !tx.Protected() { return HomesteadSigner{}.Sender(tx) } if tx.ChainId().Cmp(s.chainId) != 0 { - return common.Address{}, ErrInvalidChainId + return common.Address{}, fmt.Errorf("EIP155Signer %w: have %d want %d", ErrInvalidChainId, tx.ChainId(), s.chainId) } - V := new(big.Int).Sub(tx.data.V, s.chainIdMul) + V, R, S := tx.RawSignatureValues() + V = new(big.Int).Sub(V, s.chainIdMul) V.Sub(V, big8) - return recoverPlain(s.Hash(tx), tx.data.R, tx.data.S, V, true) + return recoverPlain(s.Hash(tx), R, S, V, true) } // WithSignature returns a new transaction with the given signature. This signature @@ -154,12 +292,12 @@ func (s EIP155Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big // It does not uniquely identify the transaction. func (s EIP155Signer) Hash(tx *Transaction) common.Hash { return rlpHash([]interface{}{ - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Recipient, - tx.data.Amount, - tx.data.Payload, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), s.chainId, uint(0), uint(0), }) } @@ -168,6 +306,10 @@ func (s EIP155Signer) Hash(tx *Transaction) common.Hash { // homestead rules. type HomesteadSigner struct{ FrontierSigner } +func (s HomesteadSigner) ChainID() *big.Int { + return nil +} + func (s HomesteadSigner) Equal(s2 Signer) bool { _, ok := s2.(HomesteadSigner) return ok @@ -180,11 +322,19 @@ func (hs HomesteadSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v } func (hs HomesteadSigner) Sender(tx *Transaction) (common.Address, error) { - return recoverPlain(hs.Hash(tx), tx.data.R, tx.data.S, tx.data.V, true) + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } + v, r, s := tx.RawSignatureValues() + return recoverPlain(hs.Hash(tx), r, s, v, true) } type FrontierSigner struct{} +func (s FrontierSigner) ChainID() *big.Int { + return nil +} + func (s FrontierSigner) Equal(s2 Signer) bool { _, ok := s2.(FrontierSigner) return ok @@ -193,12 +343,10 @@ func (s FrontierSigner) Equal(s2 Signer) bool { // SignatureValues returns signature values. This signature // needs to be in the [R || S || V] format where V is 0 or 1. func (fs FrontierSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v *big.Int, err error) { - if len(sig) != 65 { - panic(fmt.Sprintf("wrong size for signature: got %d, want 65", len(sig))) + if tx.Type() != LegacyTxType { + return nil, nil, nil, ErrTxTypeNotSupported } - r = new(big.Int).SetBytes(sig[:32]) - s = new(big.Int).SetBytes(sig[32:64]) - v = new(big.Int).SetBytes([]byte{sig[64] + 27}) + r, s, v = decodeSignature(sig) return r, s, v, nil } @@ -206,17 +354,31 @@ func (fs FrontierSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v * // It does not uniquely identify the transaction. func (fs FrontierSigner) Hash(tx *Transaction) common.Hash { return rlpHash([]interface{}{ - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Recipient, - tx.data.Amount, - tx.data.Payload, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), }) } func (fs FrontierSigner) Sender(tx *Transaction) (common.Address, error) { - return recoverPlain(fs.Hash(tx), tx.data.R, tx.data.S, tx.data.V, false) + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } + v, r, s := tx.RawSignatureValues() + return recoverPlain(fs.Hash(tx), r, s, v, false) +} + +func decodeSignature(sig []byte) (r, s, v *big.Int) { + if len(sig) != crypto.SignatureLength { + panic(fmt.Sprintf("wrong size for signature: got %d, want %d", len(sig), crypto.SignatureLength)) + } + r = new(big.Int).SetBytes(sig[:32]) + s = new(big.Int).SetBytes(sig[32:64]) + v = new(big.Int).SetBytes([]byte{sig[64] + 27}) + return r, s, v } func recoverPlain(sighash common.Hash, R, S, Vb *big.Int, homestead bool) (common.Address, error) { diff --git a/core/types/transaction_signing_test.go b/core/types/transaction_signing_test.go index e538ee3b27..535ee13990 100644 --- a/core/types/transaction_signing_test.go +++ b/core/types/transaction_signing_test.go @@ -17,6 +17,7 @@ package types import ( + "errors" "math/big" "testing" @@ -127,8 +128,8 @@ func TestChainId(t *testing.T) { } _, err = Sender(NewEIP155Signer(big.NewInt(2)), tx) - if err != ErrInvalidChainId { - t.Error("expected error:", ErrInvalidChainId) + if !errors.Is(err, ErrInvalidChainId) { + t.Error("expected error:", ErrInvalidChainId, err) } _, err = Sender(NewEIP155Signer(big.NewInt(1)), tx) diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index bc8195c986..20fa6132f2 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -19,7 +19,6 @@ package types import ( "bytes" "crypto/ecdsa" - "encoding/json" "math/big" "testing" @@ -191,45 +190,3 @@ func TestTransactionPriceNonceSort(t *testing.T) { } } } - -// TestTransactionJSON tests serializing/de-serializing to/from JSON. -func TestTransactionJSON(t *testing.T) { - key, err := crypto.GenerateKey() - if err != nil { - t.Fatalf("could not generate key: %v", err) - } - signer := NewEIP155Signer(common.Big1) - - for i := uint64(0); i < 25; i++ { - var tx *Transaction - switch i % 2 { - case 0: - tx = NewTransaction(i, common.Address{1}, common.Big0, 1, common.Big2, []byte("abcdef")) - case 1: - tx = NewContractCreation(i, common.Big0, 1, common.Big2, []byte("abcdef")) - } - - tx, err := SignTx(tx, signer, key) - if err != nil { - t.Fatalf("could not sign transaction: %v", err) - } - - data, err := json.Marshal(tx) - if err != nil { - t.Errorf("json.Marshal failed: %v", err) - } - - var parsedTx *Transaction - if err := json.Unmarshal(data, &parsedTx); err != nil { - t.Errorf("json.Unmarshal failed: %v", err) - } - - // compare nonce, price, gaslimit, recipient, amount, payload, V, R, S - if tx.Hash() != parsedTx.Hash() { - t.Errorf("parsed tx differs from original tx, want %v, got %v", tx, parsedTx) - } - if tx.ChainId().Cmp(parsedTx.ChainId()) != 0 { - t.Errorf("invalid chain id, want %d, got %d", tx.ChainId(), parsedTx.ChainId()) - } - } -} diff --git a/core/types/tx_legacy.go b/core/types/tx_legacy.go new file mode 100644 index 0000000000..84bd1b51f4 --- /dev/null +++ b/core/types/tx_legacy.go @@ -0,0 +1,119 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "bytes" + "math/big" + + "github.com/tomochain/tomochain/common" +) + +// LegacyTx is the transaction inner of the original Ethereum transactions. +type LegacyTx struct { + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input inner + V, R, S *big.Int // signature values +} + +// NewTransaction creates an unsigned legacy transaction. +// Deprecated: use NewTx instead. +func NewTransaction(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { + return NewTx(&LegacyTx{ + Nonce: nonce, + To: &to, + Value: amount, + Gas: gasLimit, + GasPrice: gasPrice, + Data: data, + }) +} + +// NewContractCreation creates an unsigned legacy transaction. +// Deprecated: use NewTx instead. +func NewContractCreation(nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { + return NewTx(&LegacyTx{ + Nonce: nonce, + Value: amount, + Gas: gasLimit, + GasPrice: gasPrice, + Data: data, + }) +} + +// copy creates a deep copy of the transaction data and initializes all fields. +func (tx *LegacyTx) copy() TxData { + cpy := &LegacyTx{ + Nonce: tx.Nonce, + To: copyAddressPtr(tx.To), + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are initialized below. + Value: new(big.Int), + GasPrice: new(big.Int), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +// accessors for innerTx. +func (tx *LegacyTx) txType() byte { return LegacyTxType } +func (tx *LegacyTx) chainID() *big.Int { return deriveChainId(tx.V) } +func (tx *LegacyTx) data() []byte { return tx.Data } +func (tx *LegacyTx) gas() uint64 { return tx.Gas } +func (tx *LegacyTx) gasPrice() *big.Int { return tx.GasPrice } +func (tx *LegacyTx) value() *big.Int { return tx.Value } +func (tx *LegacyTx) nonce() uint64 { return tx.Nonce } +func (tx *LegacyTx) to() *common.Address { return tx.To } +func (tx *LegacyTx) pmPayload() []byte { return nil } + +func (tx *LegacyTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *LegacyTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.V, tx.R, tx.S = v, r, s +} + +func (tx *LegacyTx) encode(*bytes.Buffer) error { + panic("encode called on LegacyTx") +} + +func (tx *LegacyTx) decode([]byte) error { + panic("decode called on LegacyTx)") +} diff --git a/core/types/tx_paymaster.go b/core/types/tx_paymaster.go new file mode 100644 index 0000000000..fbcdc945d9 --- /dev/null +++ b/core/types/tx_paymaster.go @@ -0,0 +1,87 @@ +package types + +import ( + "bytes" + "math/big" + + "github.com/tomochain/tomochain/rlp" + + "github.com/tomochain/tomochain/common" +) + +// PaymasterTx indicates the transactions which can be sponsored gas by a custom paymaster contract. +type PaymasterTx struct { + ChainID *big.Int // destination chain ID + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input inner + PmPayload []byte // Payload for calling paymaster contracts = PM contract address (required) + custom payload (if any) + V, R, S *big.Int // signature values +} + +// copy creates a deep copy of the transaction data and initializes all fields. +func (tx *PaymasterTx) copy() TxData { + cpy := &PaymasterTx{ + Nonce: tx.Nonce, + To: copyAddressPtr(tx.To), + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are initialized below. + Value: new(big.Int), + GasPrice: new(big.Int), + ChainID: new(big.Int), + PmPayload: common.CopyBytes(tx.PmPayload), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.ChainID != nil { + cpy.ChainID.Set(tx.ChainID) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +// accessors for innerTx. +func (tx *PaymasterTx) txType() byte { return PaymasterTxType } +func (tx *PaymasterTx) chainID() *big.Int { return tx.ChainID } +func (tx *PaymasterTx) data() []byte { return tx.Data } +func (tx *PaymasterTx) gas() uint64 { return tx.Gas } +func (tx *PaymasterTx) gasPrice() *big.Int { return tx.GasPrice } +func (tx *PaymasterTx) value() *big.Int { return tx.Value } +func (tx *PaymasterTx) nonce() uint64 { return tx.Nonce } +func (tx *PaymasterTx) to() *common.Address { return tx.To } +func (tx *PaymasterTx) pmPayload() []byte { return tx.PmPayload } + +func (tx *PaymasterTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *PaymasterTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.V, tx.R, tx.S = v, r, s +} + +func (tx *PaymasterTx) encode(b *bytes.Buffer) error { + return rlp.Encode(b, tx) +} + +func (tx *PaymasterTx) decode(input []byte) error { + return rlp.DecodeBytes(input, tx) +} diff --git a/core/types/types_test.go b/core/types/types_test.go new file mode 100644 index 0000000000..03c29a159b --- /dev/null +++ b/core/types/types_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "math/big" + "testing" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/rlp" +) + +type devnull struct{ len int } + +func (d *devnull) Write(p []byte) (int, error) { + d.len += len(p) + return len(p), nil +} + +func BenchmarkEncodeRLP(b *testing.B) { + benchRLP(b, true) +} + +func BenchmarkDecodeRLP(b *testing.B) { + benchRLP(b, false) +} + +func benchRLP(b *testing.B, encode bool) { + key, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + to := common.HexToAddress("0x00000000000000000000000000000000deadbeef") + signer := NewEIP155Signer(big.NewInt(1337)) + tx := NewTransaction(1, to, big.NewInt(1), 1000000, big.NewInt(500), nil) + signedTx, err := SignTx(tx, signer, key) + if err != nil { + b.Fatal("cannot sign transaction for benchmarking") + } + for _, tc := range []struct { + name string + obj interface{} + }{ + { + "header", + &Header{ + Difficulty: big.NewInt(10000000000), + Number: big.NewInt(1000), + GasLimit: 8_000_000, + GasUsed: 8_000_000, + Time: big.NewInt(555), + Extra: make([]byte, 32), + }, + }, + { + "receipt-for-storage", + &ReceiptForStorage{ + Status: ReceiptStatusSuccessful, + CumulativeGasUsed: 0x888888888, + Logs: make([]*Log, 0), + }, + }, + { + "receipt-full", + &Receipt{ + Status: ReceiptStatusSuccessful, + CumulativeGasUsed: 0x888888888, + Logs: make([]*Log, 0), + }, + }, + { + "transaction", + signedTx, + }, + } { + if encode { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + var null = &devnull{} + for i := 0; i < b.N; i++ { + rlp.Encode(null, tc.obj) + } + b.SetBytes(int64(null.len / b.N)) + }) + } else { + data, _ := rlp.EncodeToBytes(tc.obj) + // Test decoding + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if err := rlp.DecodeBytes(data, tc.obj); err != nil { + b.Fatal(err) + } + } + b.SetBytes(int64(len(data))) + }) + } + } +} diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go index ba31cf4945..7e7df4f891 100644 --- a/core/vm/gas_table_test.go +++ b/core/vm/gas_table_test.go @@ -17,11 +17,12 @@ package vm import ( - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/core/state" @@ -81,7 +82,7 @@ func TestEIP2200(t *testing.T) { for i, tt := range eip2200Tests { address := common.BytesToAddress([]byte("contract")) db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) statedb.CreateAccount(address) statedb.SetCode(address, hexutil.MustDecode(tt.input)) statedb.SetState(address, common.Hash{}, common.BytesToHash([]byte{tt.original})) @@ -91,7 +92,7 @@ func TestEIP2200(t *testing.T) { CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true }, Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, } - vmenv := NewEVM(vmctx, statedb, nil,params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}}) + vmenv := NewEVM(vmctx, statedb, nil, params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}}) _, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, tt.gaspool, new(big.Int)) if err != tt.failure { diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 16f3685852..ab962bd65d 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -17,13 +17,13 @@ package vm import ( - "github.com/tomochain/tomochain/params" "math/big" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core/types" - "golang.org/x/crypto/sha3" + "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/params" ) var ( @@ -381,7 +381,7 @@ func opSha3(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by data := callContext.memory.GetPtr(offset.Int64(), size.Int64()) if interpreter.hasher == nil { - interpreter.hasher = sha3.NewLegacyKeccak256().(keccakState) + interpreter.hasher = crypto.NewKeccakState() } else { interpreter.hasher.Reset() } @@ -513,16 +513,21 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // opExtCodeHash returns the code hash of a specified account. // There are several cases when the function is called, while we can relay everything // to `state.GetCodeHash` function to ensure the correctness. -// (1) Caller tries to get the code hash of a normal contract account, state +// +// (1) Caller tries to get the code hash of a normal contract account, state +// // should return the relative code hash and set it as the result. // -// (2) Caller tries to get the code hash of a non-existent account, state should +// (2) Caller tries to get the code hash of a non-existent account, state should +// // return common.Hash{} and zero will be set as the result. // -// (3) Caller tries to get the code hash for an account without contract code, +// (3) Caller tries to get the code hash for an account without contract code, +// // state should return emptyCodeHash(0xc5d246...) as the result. // -// (4) Caller tries to get the code hash of a precompiled account, the result +// (4) Caller tries to get the code hash of a precompiled account, the result +// // should be zero or emptyCodeHash. // // It is worth noting that in order to avoid unnecessary create and clean, @@ -531,10 +536,12 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // If the precompile account is not transferred any amount on a private or // customized chain, the return value will be zero. // -// (5) Caller tries to get the code hash for an account which is marked as suicided +// (5) Caller tries to get the code hash for an account which is marked as suicided +// // in the current transaction, the code hash of this account should be returned. // -// (6) Caller tries to get the code hash for an account which is marked as deleted, +// (6) Caller tries to get the code hash for an account which is marked as deleted, +// // this account should be regarded as a non-existent account and zero should be returned. func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { slot := callContext.stack.peek() diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index fc5b17a4f3..36027be797 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -17,11 +17,11 @@ package vm import ( - "hash" "sync/atomic" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/log" ) @@ -70,14 +70,6 @@ type callCtx struct { contract *Contract } -// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - // EVMInterpreter represents an EVM interpreter type EVMInterpreter struct { evm *EVM @@ -85,8 +77,8 @@ type EVMInterpreter struct { intPool *intPool - hasher keccakState // Keccak256 hasher instance shared across opcodes - hasherBuf common.Hash // Keccak256 hasher result array shared aross opcodes + hasher crypto.KeccakState // Keccak256 hasher instance shared across opcodes + hasherBuf common.Hash // Keccak256 hasher result array shared across opcodes readOnly bool // Whether to throw on stateful modifications returnData []byte // Last CALL's return data for subsequent reuse diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index 683cad1d1c..9a13d3d6f6 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -17,11 +17,12 @@ package runtime import ( - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/vm" @@ -100,7 +101,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { if cfg.State == nil { db := rawdb.NewMemoryDatabase() - cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) } var ( address = common.BytesToAddress([]byte("contract")) @@ -131,7 +132,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) { if cfg.State == nil { db := rawdb.NewMemoryDatabase() - cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db)) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) } var ( vmenv = NewEnv(cfg) diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index e430c2b2ae..0b95751d34 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -17,11 +17,12 @@ package runtime import ( - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" @@ -99,7 +100,7 @@ func TestExecute(t *testing.T) { func TestCall(t *testing.T) { db := rawdb.NewMemoryDatabase() - state, _ := state.New(common.Hash{}, state.NewDatabase(db)) + state, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) address := common.HexToAddress("0x0a") state.SetCode(address, []byte{ byte(vm.PUSH1), 10, @@ -156,7 +157,7 @@ func BenchmarkCall(b *testing.B) { func benchmarkEVM_Create(bench *testing.B, code string) { var ( db = rawdb.NewMemoryDatabase() - statedb, _ = state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) sender = common.BytesToAddress([]byte("sender")) receiver = common.BytesToAddress([]byte("receiver")) ) diff --git a/crypto/crypto.go b/crypto/crypto.go index 18386f85c0..6affee64ce 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -23,6 +23,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash" "io" "io/ioutil" "math/big" @@ -30,38 +31,72 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" - "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/rlp" + "golang.org/x/crypto/sha3" ) +// SignatureLength indicates the byte length required to carry a signature with recovery id. +const SignatureLength = 64 + 1 // 64 bytes ECDSA signature + 1 byte recovery id + +// RecoveryIDOffset points to the byte offset within the signature that contains the recovery id. +const RecoveryIDOffset = 64 + +// DigestLength sets the signature digest exact length +const DigestLength = 32 + var ( secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2)) ) +var errInvalidPubkey = errors.New("invalid secp256k1 public key") + +// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports +// Read to get a variable amount of data from the hash state. Read is faster than Sum +// because it doesn't copy the internal state, but also modifies the internal state. +type KeccakState interface { + hash.Hash + Read([]byte) (int, error) +} + +// NewKeccakState creates a new KeccakState +func NewKeccakState() KeccakState { + return sha3.NewLegacyKeccak256().(KeccakState) +} + +// HashData hashes the provided data using the KeccakState and returns a 32 byte hash +func HashData(kh KeccakState, data []byte) (h common.Hash) { + kh.Reset() + kh.Write(data) + kh.Read(h[:]) + return h +} + // Keccak256 calculates and returns the Keccak256 hash of the input data. func Keccak256(data ...[]byte) []byte { - d := sha3.NewKeccak256() + b := make([]byte, 32) + d := NewKeccakState() for _, b := range data { d.Write(b) } - return d.Sum(nil) + d.Read(b) + return b } // Keccak256Hash calculates and returns the Keccak256 hash of the input data, // converting it to an internal Hash data structure. func Keccak256Hash(data ...[]byte) (h common.Hash) { - d := sha3.NewKeccak256() + d := NewKeccakState() for _, b := range data { d.Write(b) } - d.Sum(h[:0]) + d.Read(h[:]) return h } // Keccak512 calculates and returns the Keccak512 hash of the input data. func Keccak512(data ...[]byte) []byte { - d := sha3.NewKeccak512() + d := sha3.NewLegacyKeccak512() for _, b := range data { d.Write(b) } diff --git a/eth/api.go b/eth/api.go index 76a466a49f..e885f6d600 100644 --- a/eth/api.go +++ b/eth/api.go @@ -28,6 +28,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" @@ -343,7 +344,7 @@ func NewPrivateDebugAPI(config *params.ChainConfig, eth *Ethereum) *PrivateDebug // Preimage is a debug API function that returns the preimage for a sha3 hash, if known. func (api *PrivateDebugAPI) Preimage(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { - db := core.PreimageTable(api.eth.ChainDb()) + db := rawdb.PreimageTable(api.eth.ChainDb()) return db.Get(hash.Bytes()) } @@ -494,11 +495,10 @@ func (api *PublicEthereumAPI) ChainId() hexutil.Uint64 { } // GetOwner return masternode owner of the given coinbase address -func (api *PublicEthereumAPI) GetOwnerByCoinbase(ctx context.Context, coinbase common.Address, blockNr rpc.BlockNumber) (common.Address, error) { +func (api *PublicEthereumAPI) GetOwnerByCoinbase(ctx context.Context, coinbase common.Address, blockNr rpc.BlockNumber) (common.Address, error) { statedb, _, err := api.e.ApiBackend.StateAndHeaderByNumber(ctx, blockNr) if err != nil { return common.Address{}, err } return statedb.GetOwner(coinbase), nil } - diff --git a/eth/api_backend.go b/eth/api_backend.go index 67554b4480..ddd44455ed 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -21,23 +21,19 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - - "github.com/tomochain/tomochain/consensus/posv" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" + "github.com/tomochain/tomochain/consensus/posv" "github.com/tomochain/tomochain/contracts" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" stateDatabase "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" @@ -50,6 +46,9 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) // EthApiBackend implements ethapi.Backend for full nodes @@ -117,11 +116,11 @@ func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t } func (b *EthApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - return core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)), nil + return rawdb.GetBlockReceipts(b.eth.chainDb, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()), nil } func (b *EthApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - receipts := core.GetBlockReceipts(b.eth.chainDb, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + receipts := rawdb.GetBlockReceipts(b.eth.chainDb, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()) if receipts == nil { return nil, nil } @@ -136,8 +135,8 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - state.SetBalance(msg.From(), math.MaxBig256) +func (b *EthApiBackend) GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From, math.MaxBig256) vmError := func() error { return nil } context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) diff --git a/eth/api_test.go b/eth/api_test.go index f9f2fc43da..f0a48df523 100644 --- a/eth/api_test.go +++ b/eth/api_test.go @@ -17,10 +17,11 @@ package eth import ( - "github.com/tomochain/tomochain/core/rawdb" "reflect" "testing" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/davecgh/go-spew/spew" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/state" @@ -32,7 +33,7 @@ func TestStorageRangeAt(t *testing.T) { // Create a state where account 0x010000... has a few storage entries. var ( db = rawdb.NewMemoryDatabase() - state, _ = state.New(common.Hash{}, state.NewDatabase(db)) + state, _ = state.New(common.Hash{}, state.NewDatabase(db), nil) addr = common.Address{0x01} keys = []common.Hash{ // hashes of Keys of storage common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), diff --git a/eth/api_tracer.go b/eth/api_tracer.go index e1744dc2c1..b941597279 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -21,7 +21,6 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/tomox/tradingstate" "io/ioutil" "math/big" "runtime" @@ -31,6 +30,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -39,6 +39,7 @@ import ( "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" "github.com/tomochain/tomochain/trie" ) @@ -144,7 +145,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl return nil, fmt.Errorf("parent block #%d not found", number-1) } } - statedb, err := state.New(start.Root(), database) + statedb, err := state.New(start.Root(), database, nil) var tomoxState *tradingstate.TradingStateDB if err != nil { // If the starting state is missing, allow some number of blocks to be reexecuted @@ -158,7 +159,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl if start == nil { break } - if statedb, err = state.New(start.Root(), database); err == nil { + if statedb, err = state.New(start.Root(), database, nil); err == nil { tomoxState, err = tradingstate.New(start.Root(), tradingstate.NewDatabase(api.eth.TomoX.GetLevelDB())) if err == nil { break @@ -198,13 +199,13 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl feeCapacity := state.GetTRC21FeeCapacityFromState(task.statedb) // Trace all the transactions contained within for i, tx := range task.block.Transactions() { - var balacne *big.Int + var balanceFee *big.Int if tx.To() != nil { if value, ok := feeCapacity[*tx.To()]; ok { - balacne = value + balanceFee = value } } - msg, _ := tx.AsMessage(signer, balacne, task.block.Number()) + msg, _ := core.TransactionToMessage(tx, signer, balanceFee, task.block.Number()) vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) @@ -438,13 +439,13 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, // Fetch and execute the next transaction trace tasks for task := range jobs { feeCapacity := state.GetTRC21FeeCapacityFromState(task.statedb) - var balacne *big.Int + var balanceFee *big.Int if txs[task.index].To() != nil { if value, ok := feeCapacity[*txs[task.index].To()]; ok { - balacne = value + balanceFee = value } } - msg, _ := txs[task.index].AsMessage(signer, balacne, block.Number()) + msg, _ := core.TransactionToMessage(txs[task.index], signer, balanceFee, block.Number()) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) @@ -462,19 +463,19 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, for i, tx := range txs { // Send the trace task over for execution jobs <- &txTraceTask{statedb: statedb.Copy(), index: i} - var balacne *big.Int + var balanceFee *big.Int if tx.To() != nil { if value, ok := feeCapacity[*tx.To()]; ok { - balacne = value + balanceFee = value } } // Generate the next state snapshot fast without tracing - msg, _ := tx.AsMessage(signer, balacne, block.Number()) + msg, _ := core.TransactionToMessage(tx, signer, balanceFee, block.Number()) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) vmenv := vm.NewEVM(vmctx, statedb, tomoxState, api.config, vm.Config{}) owner := common.Address{} - if _, _, _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner); err != nil { + if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.GasLimit), owner); err != nil { failed = err break } @@ -513,7 +514,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* if block == nil { break } - if statedb, err = state.New(block.Root(), database); err == nil { + if statedb, err = state.New(block.Root(), database, nil); err == nil { tomoxState, err = tradingstate.New(block.Root(), tradingstate.NewDatabase(api.eth.TomoX.GetLevelDB())) if err == nil { break @@ -567,14 +568,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* } size, _ := database.TrieDB().Size() log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", size) - return statedb,tomoxState, nil + return statedb, tomoxState, nil } // TraceTransaction returns the structured logs created during the execution of EVM // and returns them as a JSON object. func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Hash, config *TraceConfig) (interface{}, error) { // Retrieve the transaction and assemble its EVM context - tx, blockHash, _, index := core.GetTransaction(api.eth.ChainDb(), hash) + tx, blockHash, _, index := rawdb.GetTransaction(api.eth.ChainDb(), hash) if tx == nil { return nil, fmt.Errorf("transaction %x not found", hash) } @@ -593,7 +594,7 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Ha // traceTx configures a new tracer according to the provided configuration, and // executes the given message in the provided environment. The return value will // be tracer dependent. -func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { +func (api *PrivateDebugAPI) traceTx(ctx context.Context, message *core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { // Assemble the structured logger or the JavaScript tracer var ( tracer vm.Tracer @@ -630,7 +631,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v vmenv := vm.NewEVM(vmctx, statedb, nil, api.config, vm.Config{Debug: true, Tracer: tracer}) owner := common.Address{} - ret, gas, failed, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner) + result, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.GasLimit), owner) if err != nil { return nil, fmt.Errorf("tracing failed: %v", err) } @@ -638,9 +639,9 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v switch tracer := tracer.(type) { case *vm.StructLogger: return ðapi.ExecutionResult{ - Gas: gas, - Failed: failed, - ReturnValue: fmt.Sprintf("%x", ret), + Gas: result.UsedGas, + Failed: result.Failed(), + ReturnValue: fmt.Sprintf("%x", result.Return()), StructLogs: ethapi.FormatLogs(tracer.StructLogs()), }, nil @@ -653,7 +654,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v } // computeTxEnv returns the execution environment of a certain transaction. -func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, reexec uint64) (core.Message, vm.Context, *state.StateDB, error) { +func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, reexec uint64) (*core.Message, vm.Context, *state.StateDB, error) { // Create the parent state database block := api.eth.blockchain.GetBlockByHash(blockHash) if block == nil { @@ -687,7 +688,7 @@ func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, ree balanceFee = value } } - msg, err := tx.AsMessage(types.MakeSigner(api.config, block.Header().Number), balanceFee, block.Number()) + msg, err := core.TransactionToMessage(tx, types.MakeSigner(api.config, block.Header().Number), balanceFee, block.Number()) if err != nil { return nil, vm.Context{}, nil, fmt.Errorf("tx %x failed: %v", tx.Hash(), err) } diff --git a/eth/backend.go b/eth/backend.go index 412c67d230..1fe58ca2da 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -18,6 +18,7 @@ package eth import ( + "bytes" "errors" "fmt" "math/big" @@ -27,18 +28,10 @@ import ( "sync/atomic" "time" - "github.com/tomochain/tomochain/tomoxlending" - - "github.com/tomochain/tomochain/accounts/abi/bind" - "github.com/tomochain/tomochain/common/hexutil" - "github.com/tomochain/tomochain/core/state" - "github.com/tomochain/tomochain/eth/filters" - "github.com/tomochain/tomochain/rlp" - - "bytes" - "github.com/tomochain/tomochain/accounts" + "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/consensus/posv" @@ -46,11 +39,12 @@ import ( contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" - - //"github.com/tomochain/tomochain/core/state" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/eth/downloader" + "github.com/tomochain/tomochain/eth/filters" "github.com/tomochain/tomochain/eth/gasprice" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -60,8 +54,10 @@ import ( "github.com/tomochain/tomochain/node" "github.com/tomochain/tomochain/p2p" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomoxlending" ) type LesServer interface { @@ -125,6 +121,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi if !config.SyncMode.IsValid() { return nil, fmt.Errorf("invalid sync mode %d", config.SyncMode) } + chainDb, err := CreateDB(ctx, config, "chaindata") if err != nil { return nil, err @@ -160,15 +157,20 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId) if !config.SkipBcVersionCheck { - bcVersion := core.GetBlockChainVersion(chainDb) + bcVersion := rawdb.GetBlockChainVersion(chainDb) if bcVersion != core.BlockChainVersion && bcVersion != 0 { return nil, fmt.Errorf("Blockchain DB version mismatch (%d / %d). Run geth upgradedb.\n", bcVersion, core.BlockChainVersion) } - core.WriteBlockChainVersion(chainDb, core.BlockChainVersion) + rawdb.WriteBlockChainVersion(chainDb, core.BlockChainVersion) } var ( vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording} - cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout} + cacheConfig = &core.CacheConfig{ + Disabled: config.NoPruning, + TrieNodeLimit: config.TrieCache, + TrieTimeLimit: config.TrieTimeout, + SnapshotLimit: config.SnapshotCache, + } ) if eth.chainConfig.Posv != nil { c := eth.engine.(*posv.Posv) @@ -187,7 +189,7 @@ func New(ctx *node.ServiceContext, config *Config, tomoXServ *tomox.TomoX, lendi if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) eth.blockchain.SetHead(compat.RewindTo) - core.WriteChainConfig(chainDb, genesisHash, chainConfig) + rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) } eth.bloomIndexer.Start(eth.blockchain) diff --git a/eth/bloombits.go b/eth/bloombits.go index abe8c5d671..39695f43e8 100644 --- a/eth/bloombits.go +++ b/eth/bloombits.go @@ -17,13 +17,13 @@ package eth import ( - "github.com/tomochain/tomochain/core/rawdb" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/bitutil" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/params" @@ -61,8 +61,8 @@ func (eth *Ethereum) startBloomHandlers() { task := <-request task.Bitsets = make([][]byte, len(task.Sections)) for i, section := range task.Sections { - head := core.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1) - if compVector, err := core.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil { + head := rawdb.GetCanonicalHash(eth.chainDb, (section+1)*params.BloomBitsBlocks-1) + if compVector, err := rawdb.GetBloomBits(eth.chainDb, task.Bit, section, head); err == nil { if blob, err := bitutil.DecompressBytes(compVector, int(params.BloomBitsBlocks)/8); err == nil { task.Bitsets[i] = blob } else { @@ -108,7 +108,7 @@ func NewBloomIndexer(db ethdb.Database, size uint64) *core.ChainIndexer { db: db, size: size, } - table := rawdb.NewTable(db, string(core.BloomBitsIndexPrefix)) + table := rawdb.NewTable(db, string(rawdb.BloomBitsIndexPrefix)) return core.NewChainIndexer(db, table, backend, size, bloomConfirms, bloomThrottling, "bloombits") } @@ -138,7 +138,7 @@ func (b *BloomIndexer) Commit() error { if err != nil { return err } - core.WriteBloomBits(batch, uint(i), b.section, b.head, bitutil.CompressBytes(bits)) + rawdb.WriteBloomBits(batch, uint(i), b.section, b.head, bitutil.CompressBytes(bits)) } return batch.Write() } diff --git a/eth/config.go b/eth/config.go index a86f084561..8b62ab7e48 100644 --- a/eth/config.go +++ b/eth/config.go @@ -48,6 +48,7 @@ var DefaultConfig = Config{ DatabaseCache: 768, TrieCache: 256, TrieTimeout: 5 * time.Minute, + SnapshotCache: 256, GasPrice: big.NewInt(0.25 * params.Shannon), TxPool: core.DefaultTxPoolConfig, @@ -93,6 +94,7 @@ type Config struct { DatabaseCache int TrieCache int TrieTimeout time.Duration + SnapshotCache int // Mining-related options Etherbase common.Address `toml:",omitempty"` diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index eba7fad779..f9faf2ff6d 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -27,7 +27,7 @@ import ( "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -225,7 +225,7 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockC stateCh: make(chan dataPack), stateSyncStart: make(chan *stateSync), syncStatsState: stateSyncStats{ - processed: core.GetTrieSyncProgress(stateDb), + processed: rawdb.GetTrieSyncProgress(stateDb), }, trackStateReq: make(chan *stateReq), } @@ -975,22 +975,22 @@ func (d *Downloader) fetchReceipts(from uint64) error { // various callbacks to handle the slight differences between processing them. // // The instrumentation parameters: -// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) -// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) -// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) -// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) -// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) -// - pending: task callback for the number of requests still needing download (detect completion/non-completability) -// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) -// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) -// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) -// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) -// - fetch: network callback to actually send a particular download request to a physical remote peer -// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) -// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) -// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks -// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) -// - kind: textual label of the type being downloaded to display in log mesages +// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) +// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) +// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) +// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) +// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) +// - pending: task callback for the number of requests still needing download (detect completion/non-completability) +// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) +// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) +// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) +// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) +// - fetch: network callback to actually send a particular download request to a physical remote peer +// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) +// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) +// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks +// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) +// - kind: textual label of the type being downloaded to display in log mesages func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool, expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error), fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int, diff --git a/eth/downloader/fakepeer.go b/eth/downloader/fakepeer.go index 4d7c5ac280..5858a05499 100644 --- a/eth/downloader/fakepeer.go +++ b/eth/downloader/fakepeer.go @@ -21,6 +21,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" ) @@ -126,7 +127,7 @@ func (p *FakePeer) RequestBodies(hashes []common.Hash) error { uncles [][]*types.Header ) for _, hash := range hashes { - block := core.GetBlock(p.db, hash, p.hc.GetBlockNumber(hash)) + block := rawdb.GetBlock(p.db, hash, p.hc.GetBlockNumber(hash)) txs = append(txs, block.Transactions()) uncles = append(uncles, block.Uncles()) @@ -140,7 +141,7 @@ func (p *FakePeer) RequestBodies(hashes []common.Hash) error { func (p *FakePeer) RequestReceipts(hashes []common.Hash) error { var receipts [][]*types.Receipt for _, hash := range hashes { - receipts = append(receipts, core.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash))) + receipts = append(receipts, rawdb.GetBlockReceipts(p.db, hash, p.hc.GetBlockNumber(hash), p.hc.Config())) } p.dl.DeliverReceipts(p.id, receipts) return nil diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 0ed4e75faa..f01f54b67d 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -29,6 +29,7 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/metrics" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -391,7 +392,7 @@ func (q *queue) Results(block bool) []*fetchResult { size += receipt.Size() } for _, tx := range result.Transactions { - size += tx.Size() + size += common.StorageSize(tx.Size()) } q.resultSize = common.StorageSize(blockCacheSizeWeight)*size + (1-common.StorageSize(blockCacheSizeWeight))*q.resultSize } @@ -767,7 +768,7 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Transactions(txLists[index])) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { + if types.DeriveSha(types.Transactions(txLists[index]), new(trie.StackTrie)) != header.TxHash || types.CalcUncleHash(uncleLists[index]) != header.UncleHash { return errInvalidBody } result.Transactions = txLists[index] @@ -785,7 +786,7 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, defer q.lock.Unlock() reconstruct := func(header *types.Header, index int, result *fetchResult) error { - if types.DeriveSha(types.Receipts(receiptList[index])) != header.ReceiptHash { + if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.StackTrie)) != header.ReceiptHash { return errInvalidReceipt } result.Receipts = receiptList[index] diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 3809a0c579..747c9f9cff 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -18,16 +18,16 @@ package downloader import ( "fmt" - "github.com/tomochain/tomochain/ethdb/memorydb" "hash" "sync" "time" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/ethdb/memorydb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/trie" ) @@ -470,6 +470,6 @@ func (s *stateSync) updateStats(written, duplicate, unexpected int, duration tim log.Info("Imported new state entries", "count", written, "elapsed", common.PrettyDuration(duration), "processed", s.d.syncStatsState.processed, "pending", s.d.syncStatsState.pending, "retry", len(s.tasks), "duplicate", s.d.syncStatsState.duplicate, "unexpected", s.d.syncStatsState.unexpected) } if written > 0 { - core.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed) + rawdb.WriteTrieSyncProgress(s.d.stateDB, s.d.syncStatsState.processed) } } diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index 65b15094d2..d1bc108fd2 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -19,14 +19,16 @@ package fetcher import ( "errors" - "github.com/hashicorp/golang-lru" "math/rand" "time" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/trie" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -468,7 +470,7 @@ func (f *Fetcher) loop() { announce.time = task.time // If the block is empty (header only), short circuit into the final import queue - if header.TxHash == types.DeriveSha(types.Transactions{}) && header.UncleHash == types.CalcUncleHash([]*types.Header{}) { + if header.TxHash == types.EmptyRootHash && header.UncleHash == types.CalcUncleHash([]*types.Header{}) { log.Trace("Block empty, skipping body retrieval", "peer", announce.origin, "number", header.Number, "hash", header.Hash()) block := types.NewBlockWithHeader(header) @@ -530,7 +532,7 @@ func (f *Fetcher) loop() { for hash, announce := range f.completing { if f.queued[hash] == nil { - txnHash := types.DeriveSha(types.Transactions(task.transactions[i])) + txnHash := types.DeriveSha(types.Transactions(task.transactions[i]), new(trie.StackTrie)) uncleHash := types.CalcUncleHash(task.uncles[i]) if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash && announce.origin == task.peer { diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index ab7e03aaa1..951b2fcd6c 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -18,7 +18,6 @@ package fetcher import ( "errors" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "sync" "sync/atomic" @@ -28,9 +27,11 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/params" + "github.com/tomochain/tomochain/trie" ) var ( @@ -38,7 +39,7 @@ var ( testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testAddress = crypto.PubkeyToAddress(testKey.PublicKey) genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000)) - unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil) + unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil, new(trie.StackTrie)) ) // makeChain creates a chain of n blocks starting at and including parent. diff --git a/eth/filters/bench_test.go b/eth/filters/bench_test.go index 3648a3db2f..9822a85e43 100644 --- a/eth/filters/bench_test.go +++ b/eth/filters/bench_test.go @@ -20,14 +20,13 @@ import ( "bytes" "context" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "testing" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/bitutil" - "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -68,18 +67,18 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { benchDataDir := node.DefaultDataDir() + "/geth/chaindata" fmt.Println("Running bloombits benchmark section size:", sectionSize) - db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") if err != nil { b.Fatalf("error opening database at %v: %v", benchDataDir, err) } - head := core.GetHeadBlockHash(db) + head := rawdb.GetHeadBlockHash(db) if head == (common.Hash{}) { b.Fatalf("chain data not found at %v", benchDataDir) } clearBloomBits(db) fmt.Println("Generating bloombits data...") - headNum := core.GetBlockNumber(db, head) + headNum := rawdb.GetBlockNumber(db, head) if headNum < sectionSize+512 { b.Fatalf("not enough blocks for running a benchmark") } @@ -94,14 +93,14 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { } var header *types.Header for i := sectionIdx * sectionSize; i < (sectionIdx+1)*sectionSize; i++ { - hash := core.GetCanonicalHash(db, i) - header = core.GetHeader(db, hash, i) + hash := rawdb.GetCanonicalHash(db, i) + header = rawdb.GetHeader(db, hash, i) if header == nil { b.Fatalf("Error creating bloomBits data") } bc.AddBloom(uint(i-sectionIdx*sectionSize), header.Bloom) } - sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*sectionSize-1) + sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*sectionSize-1) for i := 0; i < types.BloomBitLength; i++ { data, err := bc.Bitset(uint(i)) if err != nil { @@ -110,7 +109,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { comp := bitutil.CompressBytes(data) dataSize += uint64(len(data)) compSize += uint64(len(comp)) - core.WriteBloomBits(db, uint(i), sectionIdx, sectionHead, comp) + rawdb.WriteBloomBits(db, uint(i), sectionIdx, sectionHead, comp) } //if sectionIdx%50 == 0 { // fmt.Println(" section", sectionIdx, "/", cnt) @@ -130,7 +129,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { for i := 0; i < benchFilterCnt; i++ { if i%20 == 0 { db.Close() - db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, _ = rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") backend = &testBackend{mux, db, cnt, new(event.Feed), new(event.Feed), new(event.Feed), new(event.Feed)} } var addr common.Address @@ -148,7 +147,7 @@ func benchmarkBloomBits(b *testing.B, sectionSize uint64) { } func forEachKey(db ethdb.Database, startPrefix, endPrefix []byte, fn func(key []byte)) { - it := db.NewIterator(startPrefix,nil) + it := db.NewIterator(startPrefix, nil) for it.Next() { key := it.Key() cmpLen := len(key) @@ -176,15 +175,15 @@ func clearBloomBits(db ethdb.Database) { func BenchmarkNoBloomBits(b *testing.B) { benchDataDir := node.DefaultDataDir() + "/geth/chaindata" fmt.Println("Running benchmark without bloombits") - db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024,"") + db, err := rawdb.NewLevelDBDatabase(benchDataDir, 128, 1024, "") if err != nil { b.Fatalf("error opening database at %v: %v", benchDataDir, err) } - head := core.GetHeadBlockHash(db) + head := rawdb.GetHeadBlockHash(db) if head == (common.Hash{}) { b.Fatalf("chain data not found at %v", benchDataDir) } - headNum := core.GetBlockNumber(db, head) + headNum := rawdb.GetBlockNumber(db, head) clearBloomBits(db) diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index 3d92fc1ac7..75c3c5e417 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -28,6 +28,7 @@ import ( ethereum "github.com/tomochain/tomochain" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/rpc" @@ -348,11 +349,11 @@ func (es *EventSystem) lightFilterNewHead(newHeader *types.Header, callBack func for oldh.Hash() != newh.Hash() { if oldh.Number.Uint64() >= newh.Number.Uint64() { oldHeaders = append(oldHeaders, oldh) - oldh = core.GetHeader(es.backend.ChainDb(), oldh.ParentHash, oldh.Number.Uint64()-1) + oldh = rawdb.GetHeader(es.backend.ChainDb(), oldh.ParentHash, oldh.Number.Uint64()-1) } if oldh.Number.Uint64() < newh.Number.Uint64() { newHeaders = append(newHeaders, newh) - newh = core.GetHeader(es.backend.ChainDb(), newh.ParentHash, newh.Number.Uint64()-1) + newh = rawdb.GetHeader(es.backend.ChainDb(), newh.ParentHash, newh.Number.Uint64()-1) if newh == nil { // happens when CHT syncing, nothing to do newh = oldh diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index d947a672ac..077a9c41bf 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -19,7 +19,6 @@ package filters import ( "context" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "reflect" @@ -31,6 +30,7 @@ import ( "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/event" @@ -48,6 +48,10 @@ type testBackend struct { chainFeed *event.Feed } +func (b *testBackend) ChainConfig() *params.ChainConfig { + return params.TestChainConfig +} + func (b *testBackend) ChainDb() ethdb.Database { return b.db } @@ -60,23 +64,23 @@ func (b *testBackend) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumbe var hash common.Hash var num uint64 if blockNr == rpc.LatestBlockNumber { - hash = core.GetHeadBlockHash(b.db) - num = core.GetBlockNumber(b.db, hash) + hash = rawdb.GetHeadBlockHash(b.db) + num = rawdb.GetBlockNumber(b.db, hash) } else { num = uint64(blockNr) - hash = core.GetCanonicalHash(b.db, num) + hash = rawdb.GetCanonicalHash(b.db, num) } - return core.GetHeader(b.db, hash, num), nil + return rawdb.GetHeader(b.db, hash, num), nil } func (b *testBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - number := core.GetBlockNumber(b.db, blockHash) - return core.GetBlockReceipts(b.db, blockHash, number), nil + number := rawdb.GetBlockNumber(b.db, blockHash) + return rawdb.GetBlockReceipts(b.db, blockHash, number, b.ChainConfig()), nil } func (b *testBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - number := core.GetBlockNumber(b.db, blockHash) - receipts := core.GetBlockReceipts(b.db, blockHash, number) + number := rawdb.GetBlockNumber(b.db, blockHash) + receipts := rawdb.GetBlockReceipts(b.db, blockHash, number, b.ChainConfig()) logs := make([][]*types.Log, len(receipts)) for i, receipt := range receipts { @@ -122,8 +126,8 @@ func (b *testBackend) ServiceFilter(ctx context.Context, session *bloombits.Matc task.Bitsets = make([][]byte, len(task.Sections)) for i, section := range task.Sections { if rand.Int()%4 != 0 { // Handle occasional missing deliveries - head := core.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1) - task.Bitsets[i], _ = core.GetBloomBits(b.db, task.Bit, section, head) + head := rawdb.GetCanonicalHash(b.db, (section+1)*params.BloomBitsBlocks-1) + task.Bitsets[i], _ = rawdb.GetBloomBits(b.db, task.Bit, section, head) } } request <- task diff --git a/eth/filters/filter_test.go b/eth/filters/filter_test.go index bdfb6e37f8..a5ddb00dbe 100644 --- a/eth/filters/filter_test.go +++ b/eth/filters/filter_test.go @@ -18,7 +18,6 @@ package filters import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "os" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/event" @@ -50,7 +50,7 @@ func BenchmarkFilters(b *testing.B) { defer os.RemoveAll(dir) var ( - db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"") + db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "") mux = new(event.TypeMux) txFeed = new(event.Feed) rmLogsFeed = new(event.Feed) @@ -84,14 +84,14 @@ func BenchmarkFilters(b *testing.B) { } }) for i, block := range chain { - core.WriteBlock(db, block) - if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { + rawdb.WriteBlock(db, block) + if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { b.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil { b.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { + if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { b.Fatal("error writing block receipts:", err) } } @@ -115,7 +115,7 @@ func TestFilters(t *testing.T) { defer os.RemoveAll(dir) var ( - db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0,"") + db, _ = rawdb.NewLevelDBDatabase(dir, 0, 0, "") mux = new(event.TypeMux) txFeed = new(event.Feed) rmLogsFeed = new(event.Feed) @@ -144,6 +144,7 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) case 2: receipt := types.NewReceipt(nil, false, 0) receipt.Logs = []*types.Log{ @@ -153,6 +154,7 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) case 998: receipt := types.NewReceipt(nil, false, 0) receipt.Logs = []*types.Log{ @@ -162,6 +164,7 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) case 999: receipt := types.NewReceipt(nil, false, 0) receipt.Logs = []*types.Log{ @@ -171,17 +174,19 @@ func TestFilters(t *testing.T) { }, } gen.AddUncheckedReceipt(receipt) + gen.AddUncheckedTx(types.NewTransaction(999, common.HexToAddress("0x999"), big.NewInt(999), 999, nil, nil)) } }) + for i, block := range chain { - core.WriteBlock(db, block) - if err := core.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { + rawdb.WriteBlock(db, block) + if err := rawdb.WriteCanonicalHash(db, block.Hash(), block.NumberU64()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteHeadBlockHash(db, block.Hash()); err != nil { + if err := rawdb.WriteHeadBlockHash(db, block.Hash()); err != nil { t.Fatalf("failed to insert block number: %v", err) } - if err := core.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { + if err := rawdb.WriteBlockReceipts(db, block.Hash(), block.NumberU64(), receipts[i]); err != nil { t.Fatal("error writing block receipts:", err) } } diff --git a/eth/handler_test.go b/eth/handler_test.go index d8d2f00979..bee29ea90e 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -17,13 +17,14 @@ package eth import ( - "github.com/tomochain/tomochain/core/rawdb" "math" "math/big" "math/rand" "testing" "time" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" @@ -343,9 +344,9 @@ func testGetNodeData(t *testing.T, protocol int) { // Fetch for now the entire chain db hashes := []common.Hash{} - it:=db.NewIterator(nil,nil) + it := db.NewIterator(nil, nil) for it.Next() { - key:=it.Key() + key := it.Key() if len(key) == len(common.Hash{}) { hashes = append(hashes, common.BytesToHash(key)) } @@ -374,7 +375,7 @@ func testGetNodeData(t *testing.T, protocol int) { } accounts := []common.Address{testBank, acc1Addr, acc2Addr} for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { - trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb)) + trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil) for j, acc := range accounts { state, _ := pm.blockchain.State() @@ -470,7 +471,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool var ( evmux = new(event.TypeMux) pow = ethash.NewFaker() - db = rawdb.NewMemoryDatabase() + db = rawdb.NewMemoryDatabase() config = ¶ms.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked} gspec = &core.Genesis{Config: config} genesis = gspec.MustCommit(db) diff --git a/eth/sync.go b/eth/sync.go index a1224b3caf..8cf530b195 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -78,7 +78,7 @@ func (pm *ProtocolManager) txsyncLoop() { pack.txs = pack.txs[:0] for i := 0; i < len(s.txs) && size < txsyncPackSize; i++ { pack.txs = append(pack.txs, s.txs[i]) - size += s.txs[i].Size() + size += common.StorageSize(s.txs[i].Size()) } // Remove the transactions that will be sent. s.txs = s.txs[:copy(s.txs, s.txs[len(pack.txs):])] diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index 38d4075175..2764f7034a 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -20,17 +20,18 @@ import ( "crypto/ecdsa" "crypto/rand" "encoding/json" - "github.com/tomochain/tomochain/core/rawdb" "io/ioutil" "math/big" "path/filepath" "reflect" "strings" "testing" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" @@ -169,21 +170,21 @@ func TestPrestateTracerCreate2(t *testing.T) { Balance: big.NewInt(500000000000000), } db := rawdb.NewMemoryDatabase() - statedb := tests.MakePreState(db, alloc) + statedb := tests.MakePreState(db, alloc, false) // Create the tracer, the EVM environment and run it tracer, err := New("prestateTracer") if err != nil { t.Fatalf("failed to create call tracer: %v", err) } - evm := vm.NewEVM(context, statedb, nil, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer}) + evm := vm.NewEVM(context, statedb, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) - msg, err := tx.AsMessage(signer, nil, nil) + msg, err := core.TransactionToMessage(tx, signer, nil, nil) if err != nil { t.Fatalf("failed to prepare transaction for tracing: %v", err) } st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) - if _, _, _, err = st.TransitionDb(common.Address{}); err != nil { + if _, err = st.TransitionDb(common.Address{}); err != nil { t.Fatalf("failed to execute transaction: %v", err) } // Retrieve the trace result and compare against the etalon @@ -244,7 +245,7 @@ func TestCallTracer(t *testing.T) { GasPrice: tx.GasPrice(), } db := rawdb.NewMemoryDatabase() - statedb := tests.MakePreState(db, test.Genesis.Alloc) + statedb := tests.MakePreState(db, test.Genesis.Alloc, false) // Create the tracer, the EVM environment and run it tracer, err := New("callTracer") @@ -253,12 +254,12 @@ func TestCallTracer(t *testing.T) { } evm := vm.NewEVM(context, statedb, nil, test.Genesis.Config, vm.Config{Debug: true, Tracer: tracer}) - msg, err := tx.AsMessage(signer, nil, common.Big0) + msg, err := core.TransactionToMessage(tx, signer, nil, common.Big0) if err != nil { t.Fatalf("failed to prepare transaction for tracing: %v", err) } st := core.NewStateTransition(evm, msg, new(core.GasPool).AddGas(tx.Gas())) - if _, _, _, err = st.TransitionDb(common.Address{}); err != nil { + if _, err = st.TransitionDb(common.Address{}); err != nil { t.Fatalf("failed to execute transaction: %v", err) } // Retrieve the trace result and compare against the etalon diff --git a/ethclient/signer.go b/ethclient/signer.go index 664fdc6693..ae988f2f90 100644 --- a/ethclient/signer.go +++ b/ethclient/signer.go @@ -51,6 +51,9 @@ func (s *senderFromServer) Sender(tx *types.Transaction) (common.Address, error) return s.addr, nil } +func (s *senderFromServer) ChainID() *big.Int { + panic("can't sign with senderFromServer") +} func (s *senderFromServer) Hash(tx *types.Transaction) common.Hash { panic("can't sign with senderFromServer") } diff --git a/go.mod b/go.mod index 15d820f802..6c16905987 100644 --- a/go.mod +++ b/go.mod @@ -4,44 +4,45 @@ go 1.19 require ( bazil.org/fuse v0.0.0-20180421153158-65cc252bf669 - github.com/VictoriaMetrics/fastcache v1.5.7 + github.com/VictoriaMetrics/fastcache v1.6.0 github.com/aristanetworks/goarista v0.0.0-20191023202215-f096da5361bb github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6 github.com/cespare/cp v1.1.1 github.com/davecgh/go-spew v1.1.1 github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea - github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf + github.com/docker/docker v1.6.2 github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7 github.com/edsrzf/mmap-go v1.0.0 - github.com/fatih/color v1.6.0 + github.com/fatih/color v1.7.0 github.com/gizak/termui v2.2.0+incompatible github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 - github.com/go-stack/stack v1.8.0 - github.com/golang/protobuf v1.3.2 - github.com/golang/snappy v0.0.1 + github.com/go-stack/stack v1.8.1 + github.com/golang/protobuf v1.5.2 + github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb github.com/hashicorp/golang-lru v0.5.3 - github.com/huin/goupnp v1.0.0 + github.com/holiman/uint256 v1.2.2 + github.com/huin/goupnp v1.0.3 github.com/influxdata/influxdb v1.7.9 - github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 + github.com/jackpal/go-nat-pmp v1.0.2 github.com/julienschmidt/httprouter v1.3.0 github.com/karalabe/hid v1.0.0 - github.com/mattn/go-colorable v0.1.0 + github.com/mattn/go-colorable v0.1.13 github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 - github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c + github.com/olekukonko/tablewriter v0.0.5 github.com/pborman/uuid v1.2.0 github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 - github.com/pkg/errors v0.8.1 + github.com/pkg/errors v0.9.1 github.com/prometheus/prometheus v1.7.2-0.20170814170113-3101606756c5 github.com/rjeczalik/notify v0.9.2 - github.com/rs/cors v1.6.0 + github.com/rs/cors v1.7.0 github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570 - github.com/stretchr/testify v1.4.0 - github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20220722155237-a158d28d115b - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 - golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f - golang.org/x/tools v0.1.12 + github.com/stretchr/testify v1.8.1 + github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 + golang.org/x/crypto v0.1.0 + golang.org/x/net v0.8.0 + golang.org/x/sync v0.1.0 + golang.org/x/sys v0.7.0 + golang.org/x/tools v0.7.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951 gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce @@ -50,27 +51,30 @@ require ( ) require ( - github.com/cespare/xxhash/v2 v2.1.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dlclark/regexp2 v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect - github.com/google/go-cmp v0.3.1 // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect - github.com/google/uuid v1.0.0 // indirect - github.com/kr/pretty v0.3.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/holiman/bloomfilter/v2 v2.0.3 + github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e // indirect github.com/maruel/ut v1.0.2 // indirect - github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035 // indirect - github.com/mattn/go-runewidth v0.0.4 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect github.com/naoina/go-stringutil v0.1.0 // indirect github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.6.1 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 // indirect - golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect - golang.org/x/text v0.3.8 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect - gotest.tools v2.2.0+incompatible // indirect + golang.org/x/mod v0.11.0 // indirect + golang.org/x/term v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect + google.golang.org/protobuf v1.28.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2fff78c90b..693c5630c6 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= github.com/Shopify/sarama v1.23.1/go.mod h1:XLH1GYJnLVE0XCr6KdJGVJRTwY30moWNJ4sERjXX6fs= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/VictoriaMetrics/fastcache v1.5.7 h1:4y6y0G8PRzszQUYIQHHssv/jgPHAb5qQuuDNdCbyAgw= -github.com/VictoriaMetrics/fastcache v1.5.7/go.mod h1:ptDBkNMQI4RtmVo8VS/XwRY6RoTu1dAWCbrk+6WsEM8= +github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= +github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8= @@ -23,8 +23,9 @@ github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6 h1:Eey/GGQ/E5Xp1P2Ly github.com/btcsuite/btcd v0.0.0-20171128150713-2e60448ffcc6/go.mod h1:Dmm/EzmjnCiweXmzRIAiUWCInVmPgjkzgv5k4tVyXiQ= github.com/cespare/cp v1.1.1 h1:nCb6ZLdB7NRaqsm91JtQTAme2SKJzXVsdPIPkyJr1MU= github.com/cespare/cp v1.1.1/go.mod h1:SOGHArjBr4JWaSDEVpWpo/hNg6RoKrls6Oh40hiwW+s= -github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -38,8 +39,8 @@ github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea/go.mod h1:93vs github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo= github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf h1:sh8rkQZavChcmakYiSlqu2425CHyFXLZZnvm7PDpU8M= -github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v1.6.2 h1:HlFGsy+9/xrgMmhmN+NGhCc5SHGJ7I+kHosRR1xc/aI= +github.com/docker/docker v1.6.2/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/dop251/goja v0.0.0-20211022113120-dc8c55024d06/go.mod h1:R9ET47fwRVRPZnOGvHxxhuZcbrMCuiqOz3Rlrh4KSnk= github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7 h1:cVGkvrdHgyBkYeB6kMCaF5j2d9Bg4trgbIpcUrKrvk4= github.com/dop251/goja v0.0.0-20230531210528-d7324b2d74f7/go.mod h1:QMWlm50DNe14hD7t24KEqZuUdC9sOTy8W6XbCU1mlw4= @@ -50,9 +51,12 @@ github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1 github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/fatih/color v1.6.0 h1:66qjqZk8kalYAvDRtM1AdAJQI0tj4Wrue3Eq3B3pmFU= -github.com/fatih/color v1.6.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/garyburd/redigo v1.6.0/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/gizak/termui v2.2.0+incompatible h1:qvZU9Xll/Xd/Xr/YO+HfBKXhy8a8/94ao6vV9DSXzUE= github.com/gizak/termui v2.2.0+incompatible/go.mod h1:PkJoWUt/zacQKysNfQtcw1RW+eK2SxkieVBtl+4ovLA= @@ -63,40 +67,59 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= +github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk= github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= -github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= +github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA= +github.com/holiman/uint256 v1.2.2 h1:TXKcSGc2WaxPD2+bmzAsVthL4+pEN0YwXcL5qED83vk= +github.com/holiman/uint256 v1.2.2/go.mod h1:SC8Ryt4n+UBbPbIBKaG9zbbDlp4jOru9xFZmPzLUTxw= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo= -github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc= +github.com/huin/goupnp v1.0.3 h1:N8No57ls+MnjlB+JPiCVSOyy/ot7MJTqlo7rn+NYSqQ= +github.com/huin/goupnp v1.0.3/go.mod h1:ZxNlw5WqJj6wSsRK5+YfflQGXYfccj5VgQsMNixHM7Y= github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o= github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/influxdata/influxdb v1.7.9 h1:uSeBTNO4rBkbp1Be5FKRsAmglM9nlx25TzVQRQt1An4= github.com/influxdata/influxdb v1.7.9/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/influxdata/influxdb1-client v0.0.0-20190809212627-fc22c7df067e/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= -github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 h1:6OvNmYgJyexcZ3pYbTI9jWx5tHo1Dee/tWbLMfPe2TA= -github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= +github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -111,8 +134,9 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -123,13 +147,13 @@ github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e h1:e2z/lz9pvtRrE github.com/maruel/panicparse v0.0.0-20160720141634-ad661195ed0e/go.mod h1:nty42YY5QByNC5MM7q/nj938VbgPU7avs45z6NClpxI= github.com/maruel/ut v1.0.2 h1:mQTlQk3jubTbdTcza+hwoZQWhzcvE4L6K6RTtAFlA1k= github.com/maruel/ut v1.0.2/go.mod h1:RV8PwPD9dd2KFlnlCc/DB2JVvkXmyaalfc5xvmSrRSs= -github.com/mattn/go-colorable v0.1.0 h1:v2XXALHHh6zHfYTJ+cSkwtyffnaOyR1MXaA91mTrb8o= -github.com/mattn/go-colorable v0.1.0/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035 h1:USWjF42jDCSEeikX/G1g40ZWnsPXN5WkZ4jMHZWyBK4= -github.com/mattn/go-isatty v0.0.5-0.20180830101745-3fb116b82035/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= -github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y= -github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM= github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= @@ -144,15 +168,19 @@ github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 h1:shk/vn9oCoOTmwcou github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E= github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77 h1:gKl78uP/I7JZ56OFtRf7nc4m1icV38hwV0In5pEGzeA= github.com/nsf/termbox-go v0.0.0-20170211012700-3540b76b9c77/go.mod h1:IuKpRQcYE1Tfu+oAQqaLisqDeXgjyyltCfsaoYN18NQ= -github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c h1:1RHs3tNxjXGHeul8z2t6H2N2TlAqpKe5yryJztRx4Jk= -github.com/olekukonko/tablewriter v0.0.2-0.20190409134802-7e037d187b0c/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo= github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0 h1:2mOpI4JVVPBN+WQRa0WKH2eXR+Ey+uK4n7Zj0aYpIQA= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/openconfig/gnmi v0.0.0-20190823184014-89b2bf29312c/go.mod h1:t+O9It+LKzfOAhKTT5O0ehDix+MTqbtT0T9t+7zzOvc= github.com/openconfig/reference v0.0.0-20190727015836-8dfd928c9696/go.mod h1:ym2A+zigScwkSEb/cVQB0/ZMpU3rqiH6X7WRRsxgOGw= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= @@ -160,9 +188,10 @@ github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtP github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7 h1:oYW+YCJ1pachXTQmzR3rNLYGGz4g/UgFcjb28p/viDM= github.com/peterh/liner v1.1.1-0.20190123174540-a2c9a5303de7/go.mod h1:CRroGNssyjTd/qIG2FyxByd2S8JEAZXBl4qUrZf8GS0= github.com/pierrec/lz4 v0.0.0-20190327172049-315a67e90e41/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -181,10 +210,11 @@ github.com/prometheus/prometheus v1.7.2-0.20170814170113-3101606756c5/go.mod h1: github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rjeczalik/notify v0.9.2 h1:MiTWrPj55mNDHEiIX5YUSKefw/+lCQVoAFmD6oQm5w8= github.com/rjeczalik/notify v0.9.2/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM= -github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rs/cors v1.6.0 h1:G9tHG9lebljV9mfp9SNPDL36nCDxmo3zTlAf1YgvzmI= -github.com/rs/cors v1.6.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= +github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/steakknife/bloomfilter v0.0.0-20180922174646-6819c0d2a570 h1:gIlAHnH1vJb5vwEjIp5kBj/eu99p/bl0Ay2goiPe5xE= @@ -193,12 +223,16 @@ github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3 h1:njlZPzLwU639 github.com/steakknife/hamming v0.0.0-20180906055917-c99c65617cd3/go.mod h1:hpGUWaI9xL8pRQCTXQgocU38Qw1g0Us7n5PxxTwTCYU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d h1:gZZadD8H+fF+n9CmNhYL1Y0dJB+kLOmKd7FbPJLeGHs= -github.com/syndtr/goleveldb v1.0.1-0.20190923125748-758128399b1d/go.mod h1:9OrXJhf154huy1nPWmuSrkgjPUtUNhA+Zmy+6AESzuA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= +github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU= github.com/templexxx/xor v0.0.0-20181023030647-4e92f724b73b/go.mod h1:5XA7W9S6mni3h5uvOC75dA3m9CCCaS83lltmc0ukdi4= github.com/tjfoc/gmsm v1.0.1/go.mod h1:XxO4hdhhrzAd+G4CjDqaOkd0hUzmtPR/d3EiBBMn/wc= @@ -210,64 +244,100 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190404164418-38d8ce5564a5/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= +golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190912160710-24e19bdeb0f2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190801041406-cbf593c0f2f3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190912141932-bc967efca4b8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190912185636-87d9f09c5d89/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= +golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df h1:5Pf6pFKu98ODmgnpvkJ3kFUOQGGLIzLIkbzUHp47618= +golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/bsm/ratelimit.v1 v1.0.0-20160220154919-db14e161995a/go.mod h1:KF9sEfUPAXdG8Oev9e99iLGnl2uJMjc5B+4y3O7x610= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -275,7 +345,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= @@ -295,8 +364,11 @@ gopkg.in/urfave/cli.v1 v1.20.0 h1:NdAVW6RYxDif9DhDHaAortIu956m2c0v+09AZBPTbE0= gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= -gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/interfaces.go b/interfaces.go index 467b0fba56..80c4aba00a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -119,6 +119,8 @@ type CallMsg struct { GasPrice *big.Int // wei <-> gas exchange ratio Value *big.Int // amount of wei sent along with the call Data []byte // input data, usually an ABI-encoded contract method invocation + PmAddress common.Address + PmPayload []byte BalanceTokenFee *big.Int } diff --git a/internal/blocktest/test_hash.go b/internal/blocktest/test_hash.go new file mode 100644 index 0000000000..37d979e319 --- /dev/null +++ b/internal/blocktest/test_hash.go @@ -0,0 +1,60 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package utesting provides a standalone replacement for package testing. +// +// This package exists because package testing cannot easily be embedded into a +// standalone go program. It provides an API that mirrors the standard library +// testing API. + +package blocktest + +import ( + "hash" + + "golang.org/x/crypto/sha3" + + "github.com/tomochain/tomochain/common" +) + +// testHasher is the helper tool for transaction/receipt list hashing. +// The original hasher is trie, in order to get rid of import cycle, +// use the testing hasher instead. +type testHasher struct { + hasher hash.Hash +} + +// NewHasher returns a new testHasher instance. +func NewHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +// Reset resets the hash state. +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +// Update updates the hash state with the given key and value. +func (h *testHasher) Update(key, val []byte) error { + h.hasher.Write(key) + h.hasher.Write(val) + return nil +} + +// Hash returns the hash value. +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 33376a1071..e5e75b60d9 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -21,17 +21,15 @@ import ( "context" "errors" "fmt" - "github.com/tomochain/tomochain/tomoxlending/lendingstate" "math/big" "sort" "strings" "time" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/util" "github.com/tomochain/tomochain/accounts" + "github.com/tomochain/tomochain/accounts/abi" "github.com/tomochain/tomochain/accounts/abi/bind" "github.com/tomochain/tomochain/accounts/keystore" "github.com/tomochain/tomochain/common" @@ -41,6 +39,7 @@ import ( "github.com/tomochain/tomochain/consensus/posv" contractValidator "github.com/tomochain/tomochain/contracts/validator/contract" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -50,6 +49,8 @@ import ( "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending/lendingstate" ) const ( @@ -424,7 +425,8 @@ func (s *PrivateAccountAPI) SignTransaction(ctx context.Context, args SendTxArgs // safely used to calculate a signature from. // // The hash is calulcated as -// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). // // This gives context to the signed message and prevents signing of transactions. func signHash(data []byte) []byte { @@ -1023,14 +1025,17 @@ type CallArgs struct { GasPrice hexutil.Big `json:"gasPrice"` Value hexutil.Big `json:"value"` Data hexutil.Bytes `json:"data"` + + // Introduced by PaymasterTxType transaction + PmPayload hexutil.Bytes `json:"pmPayload,omitempty"` } -func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error) { +func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, vmCfg vm.Config, timeout time.Duration) (*core.ExecutionResult, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) statedb, header, err := s.b.StateAndHeaderByNumber(ctx, blockNr) if statedb == nil || err != nil { - return nil, 0, false, err + return nil, err } // Set sender address or use a default if none specified addr := args.From @@ -1052,7 +1057,18 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr balanceTokenFee := big.NewInt(0).SetUint64(gas) balanceTokenFee = balanceTokenFee.Mul(balanceTokenFee, gasPrice) // Create new call message - msg := types.NewMessage(addr, args.To, 0, args.Value.ToInt(), gas, gasPrice, args.Data, false, balanceTokenFee) + msg := &core.Message{ + To: args.To, + From: addr, + Nonce: 0, + Value: args.Value.ToInt(), + GasLimit: gas, + GasPrice: gasPrice, + Data: args.Data, + PmPayload: args.PmPayload, + BalanceTokenFee: balanceTokenFee, + SkipAccountChecks: true, + } // Setup context so it may be cancelled the call has completed // or, in case of unmetered gas, setup a context with a timeout. @@ -1068,20 +1084,20 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr block, err := s.b.BlockByNumber(ctx, blockNr) if err != nil { - return nil, 0, false, err + return nil, err } author, err := s.b.GetEngine().Author(block.Header()) if err != nil { - return nil, 0, false, err + return nil, err } tomoxState, err := s.b.TomoxService().GetTradingState(block, author) if err != nil { - return nil, 0, false, err + return nil, err } // Get a new instance of the EVM. evm, vmError, err := s.b.GetEVM(ctx, msg, statedb, tomoxState, header, vmCfg) if err != nil { - return nil, 0, false, err + return nil, err } // Wait for the context to be done and cancel the evm. Even if the // EVM has finished, cancelling may be done (repeatedly) @@ -1094,18 +1110,55 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // and apply the message. gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} - res, gas, failed, err := core.ApplyMessage(evm, msg, gp, owner) + result, err := core.ApplyMessage(evm, msg, gp, owner) if err := vmError(); err != nil { - return nil, 0, false, err + return nil, err + } + return result, err +} + +func newRevertError(result *core.ExecutionResult) *revertError { + reason, errUnpack := abi.UnpackRevert(result.Revert()) + err := errors.New("execution reverted") + if errUnpack == nil { + err = fmt.Errorf("execution reverted: %v", reason) } - return res, gas, failed, err + return &revertError{ + error: err, + reason: hexutil.Encode(result.Revert()), + } +} + +// revertError is an API error that encompassas an EVM revertal with JSON error +// code and a binary data blob. +type revertError struct { + error + reason string // revert reason hex encoded +} + +// ErrorCode returns the JSON error code for a revertal. +// See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal +func (e *revertError) ErrorCode() int { + return 3 +} + +// ErrorData returns the hex encoded revert reason. +func (e *revertError) ErrorData() interface{} { + return e.reason } // Call executes the given transaction on the state for the given block number. // It doesn't make and changes in the state/blockchain and is useful to execute and retrieve values. func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { - result, _, _, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second) - return (hexutil.Bytes)(result), err + result, err := s.doCall(ctx, args, blockNr, vm.Config{}, 5*time.Second) + if err != nil { + return nil, err + } + + if len(result.Revert()) > 0 { + return nil, newRevertError(result) + } + return result.Return(), result.Err } // EstimateGas returns an estimate of the amount of gas needed to execute the @@ -1130,19 +1183,26 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h cap = hi // Create a helper to check if a gas allowance results in an executable transaction - executable := func(gas uint64) bool { + executable := func(gas uint64) (bool, *core.ExecutionResult, error) { args.Gas = hexutil.Uint64(gas) - _, _, failed, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0) - if err != nil || failed { - return false + result, err := s.doCall(ctx, args, rpc.LatestBlockNumber, vm.Config{}, 0) + if err != nil { + if err == core.ErrIntrinsicGas { + return true, nil, nil // Special case, raise gas limit + } + return true, nil, err } - return true + return result.Failed(), result, nil } // Execute the binary search and hone in on an executable gas limit for lo+1 < hi { mid := (hi + lo) / 2 - if !executable(mid) { + failed, _, err := executable(mid) + if err != nil { + return 0, err + } + if failed { lo = mid } else { hi = mid @@ -1150,8 +1210,19 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (h } // Reject the transaction as invalid if it still fails at the highest allowance if hi == cap { - if !executable(hi) { - return 0, fmt.Errorf("gas required exceeds allowance or always failing transaction") + failed, result, err := executable(hi) + if err != nil { + return 0, nil + } + + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + if len(result.Revert()) > 0 { + return 0, newRevertError(result) + } + return 0, result.Err + } + return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap) } } return hexutil.Uint64(hi), nil @@ -1305,8 +1376,8 @@ func (s *PublicBlockChainAPI) findNearestSignedBlock(ctx context.Context, b *typ } /* - findFinalityOfBlock return finality of a block - Use blocksHashCache for to keep track - refer core/blockchain.go for more detail +findFinalityOfBlock return finality of a block +Use blocksHashCache for to keep track - refer core/blockchain.go for more detail */ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types.Block, masternodes []common.Address) (uint, error) { engine, _ := s.b.GetEngine().(*posv.Posv) @@ -1371,7 +1442,7 @@ func (s *PublicBlockChainAPI) findFinalityOfBlock(ctx context.Context, b *types. } /* - Extract signers from block +Extract signers from block */ func (s *PublicBlockChainAPI) getSigners(ctx context.Context, block *types.Block, engine *posv.Posv) ([]common.Address, error) { var err error @@ -1594,7 +1665,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr // GetTransactionByHash returns the transaction for the given hash func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, hash common.Hash) *RPCTransaction { // Try to return an already finalized transaction - if tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash); tx != nil { + if tx, blockHash, blockNumber, index := rawdb.GetTransaction(s.b.ChainDb(), hash); tx != nil { return newRPCTransaction(tx, blockHash, blockNumber, index) } // No finalized transaction, try to retrieve it from the pool @@ -1610,7 +1681,7 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, var tx *types.Transaction // Retrieve a finalized transaction, or a pooled otherwise - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { // Transaction not found anywhere, abort return nil, nil @@ -1622,7 +1693,7 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, // GetTransactionReceipt returns the transaction receipt for the given transaction hash. func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { - tx, blockHash, blockNumber, index := core.GetTransaction(s.b.ChainDb(), hash) + tx, blockHash, blockNumber, index := rawdb.GetTransaction(s.b.ChainDb(), hash) if tx == nil { return nil, nil } @@ -1837,7 +1908,7 @@ func (s *PublicTransactionPoolAPI) SendTransaction(ctx context.Context, args Sen // The sender is responsible for signing the transaction and using the correct nonce. func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, encodedTx hexutil.Bytes) (common.Hash, error) { tx := new(types.Transaction) - if err := rlp.DecodeBytes(encodedTx, tx); err != nil { + if err := tx.UnmarshalBinary(encodedTx); err != nil { return common.Hash{}, err } return submitTransaction(ctx, s.b, tx) @@ -1867,7 +1938,7 @@ func (s *PublicTomoXTransactionPoolAPI) SendLendingRawTransaction(ctx context.Co func (s *PublicTomoXTransactionPoolAPI) GetOrderTxMatchByHash(ctx context.Context, hash common.Hash) ([]*tradingstate.OrderItem, error) { var tx *types.Transaction orders := []*tradingstate.OrderItem{} - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { return []*tradingstate.OrderItem{}, nil } @@ -2598,7 +2669,7 @@ func (s *PublicTomoXTransactionPoolAPI) GetBorrows(ctx context.Context, lendingT // GetLendingTxMatchByHash returns lendingItems which have been processed at tx of the given txhash func (s *PublicTomoXTransactionPoolAPI) GetLendingTxMatchByHash(ctx context.Context, hash common.Hash) ([]*lendingstate.LendingItem, error) { var tx *types.Transaction - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { return []*lendingstate.LendingItem{}, nil } @@ -2614,7 +2685,7 @@ func (s *PublicTomoXTransactionPoolAPI) GetLendingTxMatchByHash(ctx context.Cont // GetLiquidatedTradesByTxHash returns trades which closed by TomoX protocol at the tx of the give hash func (s *PublicTomoXTransactionPoolAPI) GetLiquidatedTradesByTxHash(ctx context.Context, hash common.Hash) (lendingstate.FinalizedResult, error) { var tx *types.Transaction - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + if tx, _, _, _ = rawdb.GetTransaction(s.b.ChainDb(), hash); tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { return lendingstate.FinalizedResult{}, nil } @@ -2965,7 +3036,8 @@ func GetSignersFromBlocks(b Backend, blockNumber uint64, blockHash common.Hash, // GetStakerROI Estimate ROI for stakers using the last epoc reward // then multiple by epoch per year, if the address is not masternode of last epoch - return 0 // Formular: -// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 +// +// ROI = average_latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 func (s *PublicBlockChainAPI) GetStakerROI() float64 { blockNumber := s.b.CurrentBlock().Number().Uint64() lastCheckpointNumber := blockNumber - (blockNumber % s.b.ChainConfig().Posv.Epoch) - s.b.ChainConfig().Posv.Epoch // calculate for 2 epochs ago @@ -2991,7 +3063,8 @@ func (s *PublicBlockChainAPI) GetStakerROI() float64 { // GetStakerROIMasternode Estimate ROI for stakers of a specific masternode using the last epoc reward // then multiple by epoch per year, if the address is not masternode of last epoch - return 0 // Formular: -// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 +// +// ROI = latest_epoch_reward_for_voters*number_of_epoch_per_year/latest_total_cap*100 func (s *PublicBlockChainAPI) GetStakerROIMasternode(masternode common.Address) float64 { votersReward := s.b.GetVotersRewards(masternode) if votersReward == nil { diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 16edc3a17f..9a197d4e2a 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -19,12 +19,8 @@ package ethapi import ( "context" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "math/big" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" @@ -38,6 +34,9 @@ import ( "github.com/tomochain/tomochain/event" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) // Backend interface provides the common API services (that are provided by @@ -61,7 +60,7 @@ type Backend interface { GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetTd(blockHash common.Hash) *big.Int - GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription SubscribeChainSideEvent(ch chan<- core.ChainSideEvent) event.Subscription diff --git a/internal/guide/guide_test.go b/internal/guide/guide_test.go index 8b82725811..1dfc794184 100644 --- a/internal/guide/guide_test.go +++ b/internal/guide/guide_test.go @@ -31,6 +31,7 @@ import ( "time" "github.com/tomochain/tomochain/accounts/keystore" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/types" ) @@ -75,7 +76,7 @@ func TestAccountManagement(t *testing.T) { if err != nil { t.Fatalf("Failed to create signer account: %v", err) } - tx, chain := new(types.Transaction), big.NewInt(1) + tx, chain := types.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), nil), big.NewInt(1) // Sign a transaction with a single authorization if _, err := ks.SignTxWithPassphrase(signer, "Signer password", tx, chain); err != nil { diff --git a/les/api_backend.go b/les/api_backend.go index d8285da97d..e49ee58c4a 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -20,20 +20,17 @@ import ( "context" "encoding/json" "errors" - "github.com/tomochain/tomochain/tomox/tradingstate" - "github.com/tomochain/tomochain/tomoxlending" "io/ioutil" "math/big" "path/filepath" - "github.com/tomochain/tomochain/tomox" - "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -45,6 +42,9 @@ import ( "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rpc" + "github.com/tomochain/tomochain/tomox" + "github.com/tomochain/tomochain/tomox/tradingstate" + "github.com/tomochain/tomochain/tomoxlending" ) type LesApiBackend struct { @@ -94,19 +94,19 @@ func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*t } func (b *LesApiBackend) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) { - return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + return light.GetBlockReceipts(ctx, b.eth.odr, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()) } func (b *LesApiBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) { - return light.GetBlockLogs(ctx, b.eth.odr, blockHash, core.GetBlockNumber(b.eth.chainDb, blockHash)) + return light.GetBlockLogs(ctx, b.eth.odr, blockHash, rawdb.GetBlockNumber(b.eth.chainDb, blockHash), b.ChainConfig()) } func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - state.SetBalance(msg.From(), math.MaxBig256) +func (b *LesApiBackend) GetEVM(ctx context.Context, msg *core.Message, state *state.StateDB, tomoxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { + state.SetBalance(msg.From, math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) return vm.NewEVM(context, state, tomoxState, b.eth.chainConfig, vmCfg), state.Error, nil } diff --git a/les/backend.go b/les/backend.go index 1a5cae11b8..9cebbd40e4 100644 --- a/les/backend.go +++ b/les/backend.go @@ -28,6 +28,7 @@ import ( "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" "github.com/tomochain/tomochain/core/bloombits" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/eth/downloader" @@ -122,7 +123,7 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) { if compat, ok := genesisErr.(*params.ConfigCompatError); ok { log.Warn("Rewinding chain to upgrade configuration", "err", compat) leth.blockchain.SetHead(compat.RewindTo) - core.WriteChainConfig(chainDb, genesisHash, chainConfig) + rawdb.WriteChainConfig(chainDb, genesisHash, chainConfig) } leth.txPool = light.NewTxPool(leth.chainConfig, leth.blockchain, leth.relay) diff --git a/les/fetcher.go b/les/fetcher.go index 7edfe808bb..80568bc322 100644 --- a/les/fetcher.go +++ b/les/fetcher.go @@ -25,7 +25,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/mclock" "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/light" "github.com/tomochain/tomochain/log" @@ -280,7 +280,7 @@ func (f *lightFetcher) announce(p *peer, head *announceData) { // if one of root's children is canonical, keep it, delete other branches and root itself var newRoot *fetcherTreeNode for i, nn := range fp.root.children { - if core.GetCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { + if rawdb.GetCanonicalHash(f.pm.chainDb, nn.number) == nn.hash { fp.root.children = append(fp.root.children[:i], fp.root.children[i+1:]...) nn.parent = nil newRoot = nn @@ -363,7 +363,7 @@ func (f *lightFetcher) peerHasBlock(p *peer, hash common.Hash, number uint64) bo // // when syncing, just check if it is part of the known chain, there is nothing better we // can do since we do not know the most recent block hash yet - return core.GetCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && core.GetCanonicalHash(f.pm.chainDb, number) == hash + return rawdb.GetCanonicalHash(f.pm.chainDb, fp.root.number) == fp.root.hash && rawdb.GetCanonicalHash(f.pm.chainDb, number) == hash } // requestAmount calculates the amount of headers to be downloaded starting diff --git a/les/handler.go b/les/handler.go index b426f7fdd1..b6d7e601ac 100644 --- a/les/handler.go +++ b/les/handler.go @@ -21,7 +21,6 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "net" "sync" @@ -30,6 +29,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth/downloader" @@ -529,7 +529,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { break } // Retrieve the requested block body, stopping if enough was found - if data := core.GetBodyRLP(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 { + if data := rawdb.GetBodyRLP(pm.chainDb, hash, rawdb.GetBlockNumber(pm.chainDb, hash)); len(data) != 0 { bodies = append(bodies, data) bytes += len(data) } @@ -580,7 +580,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } for _, req := range req.Reqs { // Retrieve the requested state entry, stopping if enough was found - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { statedb, err := pm.blockchain.State() if err != nil { continue @@ -646,7 +646,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { break } // Retrieve the requested block's receipts, skipping if unknown to us - results := core.GetBlockReceipts(pm.chainDb, hash, core.GetBlockNumber(pm.chainDb, hash)) + results := rawdb.GetBlockReceipts(pm.chainDb, hash, rawdb.GetBlockNumber(pm.chainDb, hash), pm.chainConfig) if results == nil { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { continue @@ -706,7 +706,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } for _, req := range req.Reqs { // Retrieve the requested state entry, stopping if enough was found - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { statedb, err := pm.blockchain.State() if err != nil { continue @@ -764,7 +764,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if statedb == nil || req.BHash != lastBHash { statedb, root, lastBHash = nil, common.Hash{}, req.BHash - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + if header := rawdb.GetHeader(pm.chainDb, req.BHash, rawdb.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { statedb, _ = pm.blockchain.State() root = header.Root } @@ -860,7 +860,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { trieDb := trie.NewDatabase(rawdb.NewTable(pm.chainDb, light.ChtTablePrefix)) for _, req := range req.Reqs { if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { - sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.CHTFrequencyServer-1) + sectionHead := rawdb.GetCanonicalHash(pm.chainDb, req.ChtNum*light.CHTFrequencyServer-1) if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) { trie, err := trie.New(root, trieDb) if err != nil { @@ -1095,18 +1095,18 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } // getAccount retrieves an account from the state based at root. -func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) { +func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (types.StateAccount, error) { trie, err := trie.New(root, statedb.Database().TrieDB()) if err != nil { - return state.Account{}, err + return types.StateAccount{}, err } - blob, err := trie.TryGet(hash[:]) + blob, err := trie.Get(hash[:]) if err != nil { - return state.Account{}, err + return types.StateAccount{}, err } - var account state.Account + var account types.StateAccount if err = rlp.DecodeBytes(blob, &account); err != nil { - return state.Account{}, err + return types.StateAccount{}, err } return account, nil } @@ -1115,10 +1115,10 @@ func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common. func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { switch id { case htCanonical: - sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.CHTFrequencyClient-1) + sectionHead := rawdb.GetCanonicalHash(pm.chainDb, (idx+1)*light.CHTFrequencyClient-1) return light.GetChtV2Root(pm.chainDb, idx, sectionHead), light.ChtTablePrefix case htBloomBits: - sectionHead := core.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) + sectionHead := rawdb.GetCanonicalHash(pm.chainDb, (idx+1)*light.BloomTrieFrequency-1) return light.GetBloomTrieRoot(pm.chainDb, idx, sectionHead), light.BloomTrieTablePrefix } return common.Hash{}, "" @@ -1129,8 +1129,8 @@ func (pm *ProtocolManager) getHelperTrieAuxData(req HelperTrieReq) []byte { switch { case req.Type == htCanonical && req.AuxReq == auxHeader && len(req.Key) == 8: blockNum := binary.BigEndian.Uint64(req.Key) - hash := core.GetCanonicalHash(pm.chainDb, blockNum) - return core.GetHeaderRLP(pm.chainDb, hash, blockNum) + hash := rawdb.GetCanonicalHash(pm.chainDb, blockNum) + return rawdb.GetHeaderRLP(pm.chainDb, hash, blockNum) } return nil } @@ -1143,9 +1143,9 @@ func (pm *ProtocolManager) txStatus(hashes []common.Hash) []txStatus { // If the transaction is unknown to the pool, try looking it up locally if stat == core.TxStatusUnknown { - if block, number, index := core.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) { + if block, number, index := rawdb.GetTxLookupEntry(pm.chainDb, hashes[i]); block != (common.Hash{}) { stats[i].Status = core.TxStatusIncluded - stats[i].Lookup = &core.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index} + stats[i].Lookup = &rawdb.TxLookupEntry{BlockHash: block, BlockIndex: number, Index: index} } } } diff --git a/les/handler_test.go b/les/handler_test.go index 225900dd52..2492a11f82 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -18,7 +18,6 @@ package les import ( "encoding/binary" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "math/rand" "testing" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/eth/downloader" @@ -304,7 +304,7 @@ func testGetReceipt(t *testing.T, protocol int) { block := bc.GetBlockByNumber(i) hashes = append(hashes, block.Hash()) - receipts = append(receipts, core.GetBlockReceipts(db, block.Hash(), block.NumberU64())) + receipts = append(receipts, rawdb.GetBlockReceipts(db, block.Hash(), block.NumberU64(), pm.chainConfig)) } // Send the hash request and verify the response cost := peer.GetRequestCost(GetReceiptsMsg, len(hashes)) @@ -555,9 +555,9 @@ func TestTransactionStatusLes2(t *testing.T) { } // check if their status is included now - block1hash := core.GetCanonicalHash(db, 1) - test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) - test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &core.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) + block1hash := rawdb.GetCanonicalHash(db, 1) + test(tx1, false, txStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 0}}) + test(tx2, false, txStatus{Status: core.TxStatusIncluded, Lookup: &rawdb.TxLookupEntry{BlockHash: block1hash, BlockIndex: 1, Index: 1}}) // create a reorg that rolls them back gchain, _ = core.GenerateChain(params.TestChainConfig, chain.GetBlockByNumber(0), ethash.NewFaker(), db, 2, func(i int, block *core.BlockGen) {}) diff --git a/les/odr_requests.go b/les/odr_requests.go index e6e68e7621..8bf12f6e8b 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -// Package light implements on-demand retrieval capable state and chain objects +// Package les implements on-demand retrieval capable state and chain objects // for the Ethereum Light Client. package les @@ -24,7 +24,7 @@ import ( "fmt" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" @@ -110,11 +110,11 @@ func (r *BlockRequest) Validate(db ethdb.Database, msg *Msg) error { body := bodies[0] // Retrieve our stored header and validate block content against it - header := core.GetHeader(db, r.Hash, r.Number) + header := rawdb.GetHeader(db, r.Hash, r.Number) if header == nil { return errHeaderUnavailable } - if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions)) { + if header.TxHash != types.DeriveSha(types.Transactions(body.Transactions), new(trie.StackTrie)) { return errTxHashMismatch } if header.UncleHash != types.CalcUncleHash(body.Uncles) { @@ -166,11 +166,11 @@ func (r *ReceiptsRequest) Validate(db ethdb.Database, msg *Msg) error { receipt := receipts[0] // Retrieve our stored header and validate receipt content against it - header := core.GetHeader(db, r.Hash, r.Number) + header := rawdb.GetHeader(db, r.Hash, r.Number) if header == nil { return errHeaderUnavailable } - if header.ReceiptHash != types.DeriveSha(receipt) { + if header.ReceiptHash != types.DeriveSha(receipt, new(trie.StackTrie)) { return errReceiptHashMismatch } // Validations passed, store and return diff --git a/les/odr_test.go b/les/odr_test.go index 3858e34028..52b5d990b3 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -19,7 +19,6 @@ package les import ( "bytes" "context" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "time" @@ -27,6 +26,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -64,9 +64,9 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, config *params.ChainCon func odrGetReceipts(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { var receipts types.Receipts if bc != nil { - receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + receipts = rawdb.GetBlockReceipts(db, bhash, rawdb.GetBlockNumber(db, bhash), config) } else { - receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) + receipts, _ = light.GetBlockReceipts(ctx, lc.Odr(), bhash, rawdb.GetBlockNumber(db, bhash), config) } if receipts == nil { return nil @@ -91,7 +91,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon for _, addr := range acc { if bc != nil { header := bc.GetHeaderByHash(bhash) - st, err = state.New(header.Root, state.NewDatabase(db)) + st, err = state.New(header.Root, state.NewDatabase(db), nil) } else { header := lc.GetHeaderByHash(bhash) st = light.NewState(ctx, header, lc.Odr()) @@ -109,12 +109,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon // //func TestOdrContractCallLes2(t *testing.T) { testOdr(t, 2, 2, odrContractCall) } -type callmsg struct { - types.Message -} - -func (callmsg) CheckNonce() bool { return false } - func odrContractCall(ctx context.Context, db ethdb.Database, config *params.ChainConfig, bc *core.BlockChain, lc *light.LightChain, bhash common.Hash) []byte { data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000") @@ -123,7 +117,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai data[35] = byte(i) if bc != nil { header := bc.GetHeaderByHash(bhash) - statedb, err := state.New(header.Root, state.NewDatabase(db)) + statedb, err := state.New(header.Root, state.NewDatabase(db), nil) if err == nil { from := statedb.GetOrNewStateObject(testBankAddress) @@ -133,16 +127,26 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee)} - + fromAddr := from.Address() + msg := &core.Message{ + To: &fromAddr, + From: testContractAddr, + Nonce: 0, + Value: new(big.Int), + GasLimit: 100000, + GasPrice: new(big.Int), + Data: data, + SkipAccountChecks: false, + BalanceTokenFee: balanceTokenFee, + } context := core.NewEVMContext(msg, header, bc, nil) vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{}) //vmenv := core.NewEnv(statedb, config, bc, msg, header, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} - ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner) - res = append(res, ret...) + ret, _ := core.ApplyMessage(vmenv, msg, gp, owner) + res = append(res, ret.Return()...) } } else { header := lc.GetHeaderByHash(bhash) @@ -153,14 +157,24 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee)} + msg := &core.Message{ + To: &testBankAddress, + From: testContractAddr, + Nonce: 0, + Value: new(big.Int), + GasLimit: 100000, + GasPrice: new(big.Int), + Data: data, + SkipAccountChecks: false, + BalanceTokenFee: balanceTokenFee, + } context := core.NewEVMContext(msg, header, lc, nil) vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} - ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner) + ret, _ := core.ApplyMessage(vmenv, msg, gp, owner) if statedb.Error() == nil { - res = append(res, ret...) + res = append(res, ret.Return()...) } } } @@ -190,7 +204,7 @@ func testOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { test := func(expFail uint64) { for i := uint64(0); i <= pm.blockchain.CurrentHeader().Number.Uint64(); i++ { - bhash := core.GetCanonicalHash(db, i) + bhash := rawdb.GetCanonicalHash(db, i) b1 := fn(light.NoOdr, db, pm.chainConfig, pm.blockchain.(*core.BlockChain), nil, bhash) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) diff --git a/les/protocol.go b/les/protocol.go index 9ca62e73e3..26e4573369 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -28,6 +28,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/secp256k1" "github.com/tomochain/tomochain/rlp" @@ -224,6 +225,6 @@ type proofsData [][]rlp.RawValue type txStatus struct { Status core.TxStatus - Lookup *core.TxLookupEntry `rlp:"nil"` + Lookup *rawdb.TxLookupEntry `rlp:"nil"` Error string } diff --git a/les/request_test.go b/les/request_test.go index 183128d839..2313e738a5 100644 --- a/les/request_test.go +++ b/les/request_test.go @@ -18,12 +18,11 @@ package les import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "testing" "time" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/ethdb" @@ -59,7 +58,7 @@ func tfReceiptsAccess(db ethdb.Database, bhash common.Hash, number uint64) light //func TestTrieEntryAccessLes2(t *testing.T) { testAccess(t, 2, tfTrieEntryAccess) } func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { - return &light.TrieRequest{Id: light.StateTrieID(core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} + return &light.TrieRequest{Id: light.StateTrieID(rawdb.GetHeader(db, bhash, rawdb.GetBlockNumber(db, bhash))), Key: testBankSecureTrieKey} } //func TestCodeAccessLes1(t *testing.T) { testAccess(t, 1, tfCodeAccess) } @@ -67,7 +66,7 @@ func tfTrieEntryAccess(db ethdb.Database, bhash common.Hash, number uint64) ligh //func TestCodeAccessLes2(t *testing.T) { testAccess(t, 2, tfCodeAccess) } func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.OdrRequest { - header := core.GetHeader(db, bhash, core.GetBlockNumber(db, bhash)) + header := rawdb.GetHeader(db, bhash, rawdb.GetBlockNumber(db, bhash)) if header.Number.Uint64() < testContractDeployed { return nil } @@ -100,7 +99,7 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) { test := func(expFail uint64) { for i := uint64(0); i <= pm.blockchain.CurrentHeader().Number.Uint64(); i++ { - bhash := core.GetCanonicalHash(db, i) + bhash := rawdb.GetCanonicalHash(db, i) if req := fn(ldb, bhash, i); req != nil { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() diff --git a/les/server.go b/les/server.go index b56d2cad4b..4705f599da 100644 --- a/les/server.go +++ b/les/server.go @@ -25,6 +25,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/eth" "github.com/tomochain/tomochain/ethdb" @@ -329,11 +330,11 @@ func (pm *ProtocolManager) blockLoop() { header := ev.Block.Header() hash := header.Hash() number := header.Number.Uint64() - td := core.GetTd(pm.chainDb, hash, number) + td := rawdb.GetTd(pm.chainDb, hash, number) if td != nil && td.Cmp(lastBroadcastTd) > 0 { var reorg uint64 if lastHead != nil { - reorg = lastHead.Number.Uint64() - core.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64() + reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(pm.chainDb, header, lastHead).Number.Uint64() } lastHead = header lastBroadcastTd = td diff --git a/les/sync.go b/les/sync.go index 8e3cd47ca3..993e96a581 100644 --- a/les/sync.go +++ b/les/sync.go @@ -20,7 +20,7 @@ import ( "context" "time" - "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/eth/downloader" "github.com/tomochain/tomochain/light" ) @@ -61,7 +61,7 @@ func (pm *ProtocolManager) syncer() { func (pm *ProtocolManager) needToSync(peerHead blockInfo) bool { head := pm.blockchain.CurrentHeader() - currentTd := core.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64()) + currentTd := rawdb.GetTd(pm.chainDb, head.Hash(), head.Number.Uint64()) return currentTd != nil && peerHead.Td.Cmp(currentTd) > 0 } diff --git a/light/lightchain.go b/light/lightchain.go index 6c91389777..42717f1ced 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -24,10 +24,12 @@ import ( "sync/atomic" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -142,7 +144,7 @@ func (self *LightChain) Odr() OdrBackend { // loadLastState loads the last known chain state from the database. This method // assumes that the chain manager mutex is held. func (self *LightChain) loadLastState() error { - if head := core.GetHeadHeaderHash(self.chainDb); head == (common.Hash{}) { + if head := rawdb.GetHeadHeaderHash(self.chainDb); head == (common.Hash{}) { // Corrupt or empty database, init from scratch self.Reset() } else { @@ -189,10 +191,10 @@ func (bc *LightChain) ResetWithGenesisBlock(genesis *types.Block) { defer bc.mu.Unlock() // Prepare the genesis block and reinitialise the chain - if err := core.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { + if err := rawdb.WriteTd(bc.chainDb, genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { log.Crit("Failed to write genesis block TD", "err", err) } - if err := core.WriteBlock(bc.chainDb, genesis); err != nil { + if err := rawdb.WriteBlock(bc.chainDb, genesis); err != nil { log.Crit("Failed to write genesis block", "err", err) } bc.genesisBlock = genesis diff --git a/light/lightchain_test.go b/light/lightchain_test.go index 21836cc88c..073efecb00 100644 --- a/light/lightchain_test.go +++ b/light/lightchain_test.go @@ -18,13 +18,13 @@ package light import ( "context" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/params" @@ -123,8 +123,8 @@ func testHeaderChainImport(chain []*types.Header, lightchain *LightChain) error } // Manually insert the header into the database, but don't reorganize (allows subsequent testing) lightchain.mu.Lock() - core.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash))) - core.WriteHeader(lightchain.chainDb, header) + rawdb.WriteTd(lightchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, lightchain.GetTdByHash(header.ParentHash))) + rawdb.WriteHeader(lightchain.chainDb, header) lightchain.mu.Unlock() } return nil diff --git a/light/odr.go b/light/odr.go index b5591fdd93..9fe919cb39 100644 --- a/light/odr.go +++ b/light/odr.go @@ -24,6 +24,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" ) @@ -112,7 +113,7 @@ type BlockRequest struct { // StoreResult stores the retrieved data in local database func (req *BlockRequest) StoreResult(db ethdb.Database) { - core.WriteBodyRLP(db, req.Hash, req.Number, req.Rlp) + rawdb.WriteBodyRLP(db, req.Hash, req.Number, req.Rlp) } // ReceiptsRequest is the ODR request type for retrieving block bodies @@ -125,7 +126,7 @@ type ReceiptsRequest struct { // StoreResult stores the retrieved data in local database func (req *ReceiptsRequest) StoreResult(db ethdb.Database) { - core.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts) + rawdb.WriteBlockReceipts(db, req.Hash, req.Number, req.Receipts) } // ChtRequest is the ODR request type for state/storage trie entries @@ -141,10 +142,10 @@ type ChtRequest struct { // StoreResult stores the retrieved data in local database func (req *ChtRequest) StoreResult(db ethdb.Database) { // if there is a canonical hash, there is a header too - core.WriteHeader(db, req.Header) + rawdb.WriteHeader(db, req.Header) hash, num := req.Header.Hash(), req.Header.Number.Uint64() - core.WriteTd(db, hash, num, req.Td) - core.WriteCanonicalHash(db, hash, num) + rawdb.WriteTd(db, hash, num, req.Td) + rawdb.WriteCanonicalHash(db, hash, num) } // BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure @@ -161,11 +162,11 @@ type BloomRequest struct { // StoreResult stores the retrieved data in local database func (req *BloomRequest) StoreResult(db ethdb.Database) { for i, sectionIdx := range req.SectionIdxList { - sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) // if we don't have the canonical hash stored for this section head number, we'll still store it under // a key with a zero sectionHead. GetBloomBits will look there too if we still don't have the canonical // hash. In the unlikely case we've retrieved the section head hash since then, we'll just retrieve the // bit vector again from the network. - core.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i]) + rawdb.WriteBloomBits(db, req.BitIdx, sectionIdx, sectionHead, req.BloomBits[i]) } } diff --git a/light/odr_test.go b/light/odr_test.go index 0c5fc78573..c4b9e10648 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -20,16 +20,16 @@ import ( "bytes" "context" "errors" - "github.com/tomochain/tomochain/consensus" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "testing" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/math" + "github.com/tomochain/tomochain/consensus" "github.com/tomochain/tomochain/consensus/ethash" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/core/vm" @@ -43,7 +43,7 @@ import ( var ( testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") testBankAddress = crypto.PubkeyToAddress(testBankKey.PublicKey) - testBankFunds = big.NewInt(100000000) + testBankFunds = big.NewInt(1_000_000_000_000_000_000) acc1Key, _ = crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") acc2Key, _ = crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") @@ -72,9 +72,12 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { } switch req := req.(type) { case *BlockRequest: - req.Rlp = core.GetBodyRLP(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) + req.Rlp = rawdb.GetBodyRLP(odr.sdb, req.Hash, rawdb.GetBlockNumber(odr.sdb, req.Hash)) case *ReceiptsRequest: - req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) + number := rawdb.GetBlockNumber(odr.sdb, req.Hash) + if number != rawdb.MissingNumber { + req.Receipts = rawdb.ReadRawReceipts(odr.sdb, req.Hash, number) + } case *TrieRequest: t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb)) nodes := NewNodeSet() @@ -110,9 +113,13 @@ func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) } func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) { var receipts types.Receipts if bc != nil { - receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) + if number := rawdb.GetBlockNumber(db, bhash); number != rawdb.MissingNumber { + if header := rawdb.GetHeader(db, bhash, number); header != nil { + receipts = rawdb.GetBlockReceipts(db, bhash, number, bc.Config()) + } + } } else { - receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) + receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, rawdb.GetBlockNumber(db, bhash), lc.Config()) } if receipts == nil { return nil, nil @@ -133,7 +140,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc st = NewState(ctx, header, lc.Odr()) } else { header := bc.GetHeaderByHash(bhash) - st, _ = state.New(header.Root, state.NewDatabase(db)) + st, _ = state.New(header.Root, state.NewDatabase(db), nil) } var res []byte @@ -148,7 +155,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) } type callmsg struct { - types.Message + core.Message } func (callmsg) CheckNonce() bool { return false } @@ -173,7 +180,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain } else { chain = bc header = bc.GetHeaderByHash(bhash) - st, _ = state.New(header.Root, state.NewDatabase(db)) + st, _ = state.New(header.Root, state.NewDatabase(db), nil) } // Perform read-only call. @@ -183,13 +190,22 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, false, balanceTokenFee)} + msg := &core.Message{ + From: testBankAddress, + To: &testContractAddr, + Value: new(big.Int), + GasLimit: 1000000, + GasPrice: new(big.Int), + Data: data, + SkipAccountChecks: true, + BalanceTokenFee: balanceTokenFee, + } context := core.NewEVMContext(msg, header, chain, nil) vmenv := vm.NewEVM(context, st, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} - ret, _, _, _ := core.ApplyMessage(vmenv, msg, gp, owner) - res = append(res, ret...) + ret, _ := core.ApplyMessage(vmenv, msg, gp, owner) + res = append(res, ret.Return()...) if st.Error() != nil { return res, st.Error() } @@ -202,17 +218,17 @@ func testChainGen(i int, block *core.BlockGen) { switch i { case 0: // In block 1, the test bank sends account #1 some ether. - tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey) + tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(10_000_000_000_000_000), params.TxGas, nil, nil), signer, testBankKey) block.AddTx(tx) case 1: // In block 2, the test bank sends some more ether to account #1. // acc1Addr passes it on to account #2. // acc1Addr creates a test contract. - tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey) + tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBankAddress), acc1Addr, big.NewInt(1_000_000_000_000_000), params.TxGas, nil, nil), signer, testBankKey) nonce := block.TxNonce(acc1Addr) - tx2, _ := types.SignTx(types.NewTransaction(nonce, acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key) + tx2, _ := types.SignTx(types.NewTransaction(nonce, acc2Addr, big.NewInt(1_000_000_000_000_000), params.TxGas, nil, nil), signer, acc1Key) nonce++ - tx3, _ := types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 1000000, big.NewInt(0), testContractCode), signer, acc1Key) + tx3, _ := types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 1000000, nil, testContractCode), signer, acc1Key) testContractAddr = crypto.CreateAddress(acc1Addr, nonce) block.AddTx(tx1) block.AddTx(tx2) @@ -240,9 +256,12 @@ func testChainGen(i int, block *core.BlockGen) { func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { var ( - sdb = rawdb.NewMemoryDatabase() - ldb = rawdb.NewMemoryDatabase() - gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}} + sdb = rawdb.NewMemoryDatabase() + ldb = rawdb.NewMemoryDatabase() + gspec = core.Genesis{ + Config: params.TestChainConfig, + Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}, + } genesis = gspec.MustCommit(sdb) ) gspec.MustCommit(ldb) @@ -268,7 +287,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { test := func(expFail int) { for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ { - bhash := core.GetCanonicalHash(sdb, i) + bhash := rawdb.GetCanonicalHash(sdb, i) b1, err := fn(NoOdr, sdb, blockchain, nil, bhash) if err != nil { t.Fatalf("error in full-node test for block %d: %v", i, err) diff --git a/light/odr_util.go b/light/odr_util.go index 89a63eb2b9..9d38c45b28 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -22,8 +22,10 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" + "github.com/tomochain/tomochain/params" "github.com/tomochain/tomochain/rlp" ) @@ -31,10 +33,10 @@ var sha3_nil = crypto.Keccak256Hash(nil) func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) { db := odr.Database() - hash := core.GetCanonicalHash(db, number) + hash := rawdb.GetCanonicalHash(db, number) if (hash != common.Hash{}) { // if there is a canonical hash, there is a header too - header := core.GetHeader(db, hash, number) + header := rawdb.GetHeader(db, hash, number) if header == nil { panic("Canonical hash present but header not found") } @@ -47,14 +49,14 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ ) if odr.ChtIndexer() != nil { chtCount, sectionHeadNum, sectionHead = odr.ChtIndexer().Sections() - canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash := rawdb.GetCanonicalHash(db, sectionHeadNum) // if the CHT was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too for chtCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { chtCount-- if chtCount > 0 { sectionHeadNum = chtCount*CHTFrequencyClient - 1 sectionHead = odr.ChtIndexer().SectionHead(chtCount - 1) - canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash = rawdb.GetCanonicalHash(db, sectionHeadNum) } } } @@ -69,7 +71,7 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ } func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (common.Hash, error) { - hash := core.GetCanonicalHash(odr.Database(), number) + hash := rawdb.GetCanonicalHash(odr.Database(), number) if (hash != common.Hash{}) { return hash, nil } @@ -82,7 +84,7 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) { - if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil { + if data := rawdb.GetBodyRLP(odr.Database(), hash, number); data != nil { return data, nil } r := &BlockRequest{Hash: hash, Number: number} @@ -111,7 +113,7 @@ func GetBody(ctx context.Context, odr OdrBackend, hash common.Hash, number uint6 // back from the stored header and body. func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (*types.Block, error) { // Retrieve the block header and body contents - header := core.GetHeader(odr.Database(), hash, number) + header := rawdb.GetHeader(odr.Database(), hash, number) if header == nil { return nil, ErrNoHeader } @@ -125,9 +127,9 @@ func GetBlock(ctx context.Context, odr OdrBackend, hash common.Hash, number uint // GetBlockReceipts retrieves the receipts generated by the transactions included // in a block given by its hash. -func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (types.Receipts, error) { +func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64, config *params.ChainConfig) (types.Receipts, error) { // Retrieve the potentially incomplete receipts from disk or network - receipts := core.GetBlockReceipts(odr.Database(), hash, number) + receipts := rawdb.GetBlockReceipts(odr.Database(), hash, number, config) if receipts == nil { r := &ReceiptsRequest{Hash: hash, Number: number} if err := odr.Retrieve(ctx, r); err != nil { @@ -141,22 +143,22 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num if err != nil { return nil, err } - genesis := core.GetCanonicalHash(odr.Database(), 0) - config, _ := core.GetChainConfig(odr.Database(), genesis) + genesis := rawdb.GetCanonicalHash(odr.Database(), 0) + config, _ := rawdb.GetChainConfig(odr.Database(), genesis) if err := core.SetReceiptsData(config, block, receipts); err != nil { return nil, err } - core.WriteBlockReceipts(odr.Database(), hash, number, receipts) + rawdb.WriteBlockReceipts(odr.Database(), hash, number, receipts) } return receipts, nil } // GetBlockLogs retrieves the logs generated by the transactions included in a // block given by its hash. -func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) ([][]*types.Log, error) { +func GetBlockLogs(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64, config *params.ChainConfig) ([][]*types.Log, error) { // Retrieve the potentially incomplete receipts from disk or network - receipts := core.GetBlockReceipts(odr.Database(), hash, number) + receipts := rawdb.GetBlockReceipts(odr.Database(), hash, number, config) if receipts == nil { r := &ReceiptsRequest{Hash: hash, Number: number} if err := odr.Retrieve(ctx, r); err != nil { @@ -187,24 +189,24 @@ func GetBloomBits(ctx context.Context, odr OdrBackend, bitIdx uint, sectionIdxLi ) if odr.BloomTrieIndexer() != nil { bloomTrieCount, sectionHeadNum, sectionHead = odr.BloomTrieIndexer().Sections() - canonicalHash := core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash := rawdb.GetCanonicalHash(db, sectionHeadNum) // if the BloomTrie was injected as a trusted checkpoint, we have no canonical hash yet so we accept zero hash too for bloomTrieCount > 0 && canonicalHash != sectionHead && canonicalHash != (common.Hash{}) { bloomTrieCount-- if bloomTrieCount > 0 { sectionHeadNum = bloomTrieCount*BloomTrieFrequency - 1 sectionHead = odr.BloomTrieIndexer().SectionHead(bloomTrieCount - 1) - canonicalHash = core.GetCanonicalHash(db, sectionHeadNum) + canonicalHash = rawdb.GetCanonicalHash(db, sectionHeadNum) } } } for i, sectionIdx := range sectionIdxList { - sectionHead := core.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) + sectionHead := rawdb.GetCanonicalHash(db, (sectionIdx+1)*BloomTrieFrequency-1) // if we don't have the canonical hash stored for this section head number, we'll still look for // an entry with a zero sectionHead (we store it with zero section head too if we don't know it // at the time of the retrieval) - bloomBits, err := core.GetBloomBits(db, bitIdx, sectionIdx, sectionHead) + bloomBits, err := rawdb.GetBloomBits(db, bitIdx, sectionIdx, sectionHead) if err == nil { result[i] = bloomBits } else { diff --git a/light/postprocess.go b/light/postprocess.go index 1e83a3cd7a..22526d943f 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -19,13 +19,13 @@ package light import ( "encoding/binary" "errors" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "time" "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/bitutil" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -162,7 +162,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) { hash, num := header.Hash(), header.Number.Uint64() c.lastHash = hash - td := core.GetTd(c.diskdb, hash, num) + td := rawdb.GetTd(c.diskdb, hash, num) if td == nil { panic(nil) } @@ -273,7 +273,7 @@ func (b *BloomTrieIndexerBackend) Commit() error { binary.BigEndian.PutUint64(encKey[2:10], b.section) var decomp []byte for j := uint64(0); j < b.bloomTrieRatio; j++ { - data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) + data, err := rawdb.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) if err != nil { return err } diff --git a/light/trie.go b/light/trie.go index d247f145ea..8d32392f46 100644 --- a/light/trie.go +++ b/light/trie.go @@ -26,11 +26,12 @@ import ( "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb" + "github.com/tomochain/tomochain/rlp" "github.com/tomochain/tomochain/trie" ) func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { - state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr)) + state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil) return state } @@ -95,27 +96,74 @@ type odrTrie struct { trie *trie.Trie } -func (t *odrTrie) TryGet(key []byte) ([]byte, error) { +func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { key = crypto.Keccak256(key) - var res []byte + var enc []byte err := t.do(key, func() (err error) { - res, err = t.trie.TryGet(key) + enc, err = t.trie.Get(key) return err }) - return res, err + if err != nil || len(enc) == 0 { + return nil, err + } + _, content, _, err := rlp.Split(enc) + return content, err +} + +func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error) { + var ( + enc []byte + key = crypto.Keccak256(address.Bytes()) + ) + err := t.do(key, func() (err error) { + enc, err = t.trie.Get(key) + return err + }) + if err != nil || len(enc) == 0 { + return nil, err + } + acct := new(types.StateAccount) + if err := rlp.DecodeBytes(enc, acct); err != nil { + return nil, err + } + return acct, nil +} + +func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { + key := crypto.Keccak256(address.Bytes()) + value, err := rlp.EncodeToBytes(acc) + if err != nil { + return fmt.Errorf("decoding error in account update: %w", err) + } + return t.do(key, func() error { + return t.trie.Update(key, value) + }) +} + +func (t *odrTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { + return nil } -func (t *odrTrie) TryUpdate(key, value []byte) error { +func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error { key = crypto.Keccak256(key) + v, _ := rlp.EncodeToBytes(value) return t.do(key, func() error { - return t.trie.TryDelete(key) + return t.trie.Update(key, v) }) } -func (t *odrTrie) TryDelete(key []byte) error { +func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { - return t.trie.TryDelete(key) + return t.trie.Delete(key) + }) +} + +// DeleteAccount abstracts an account deletion from the trie. +func (t *odrTrie) DeleteAccount(address common.Address) error { + key := crypto.Keccak256(address.Bytes()) + return t.do(key, func() error { + return t.trie.Delete(key) }) } diff --git a/light/txpool.go b/light/txpool.go index 7af86dbd6b..d5bee2a3ba 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -24,6 +24,7 @@ import ( "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/core/types" "github.com/tomochain/tomochain/ethdb" @@ -74,10 +75,13 @@ type TxPool struct { // // Send instructs backend to forward new transactions // NewHead notifies backend about a new head after processed by the tx pool, -// including mined and rolled back transactions since the last event +// +// including mined and rolled back transactions since the last event +// // Discard notifies backend about transactions that should be discarded either -// because they have been replaced by a re-send or because they have been mined -// long ago and no rollback is expected +// +// because they have been replaced by a re-send or because they have been mined +// long ago and no rollback is expected type TxRelayBackend interface { Send(txs types.Transactions) NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash) @@ -180,10 +184,10 @@ func (pool *TxPool) checkMinedTxs(ctx context.Context, hash common.Hash, number // If some transactions have been mined, write the needed data to disk and update if list != nil { // Retrieve all the receipts belonging to this block and write the loopup table - if _, err := GetBlockReceipts(ctx, pool.odr, hash, number); err != nil { // ODR caches, ignore results + if _, err := GetBlockReceipts(ctx, pool.odr, hash, number, pool.config); err != nil { // ODR caches, ignore results return err } - if err := core.WriteTxLookupEntries(pool.chainDb, block); err != nil { + if err := rawdb.WriteTxLookupEntries(pool.chainDb, block); err != nil { return err } // Update the transaction pool's state @@ -202,7 +206,7 @@ func (pool *TxPool) rollbackTxs(hash common.Hash, txc txStateChanges) { if list, ok := pool.mined[hash]; ok { for _, tx := range list { txHash := tx.Hash() - core.DeleteTxLookupEntry(pool.chainDb, txHash) + rawdb.DeleteTxLookupEntry(pool.chainDb, txHash) pool.pending[txHash] = tx txc.setState(txHash, false) } @@ -258,7 +262,7 @@ func (pool *TxPool) reorgOnNewHead(ctx context.Context, newHeader *types.Header) idx2 := idx - txPermanent if len(pool.mined) > 0 { for i := pool.clearIdx; i < idx2; i++ { - hash := core.GetCanonicalHash(pool.chainDb, i) + hash := rawdb.GetCanonicalHash(pool.chainDb, i) if list, ok := pool.mined[hash]; ok { hashes := make([]common.Hash, len(list)) for i, tx := range list { diff --git a/metrics/metrics.go b/metrics/metrics.go index dbb2727ec0..3e315b19e1 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -19,7 +19,12 @@ import ( // // This global kill-switch helps quantify the observer effect and makes // for less cluttered pprof profiles. -var Enabled bool = false +var Enabled = false + +// EnabledExpensive is a soft-flag meant for external packages to check if costly +// metrics gathering is allowed or not. The goal is to separate standard metrics +// for health monitoring and debug metrics that might impact runtime performance. +var EnabledExpensive = false // MetricsEnabledFlag is the CLI flag name to use to enable metrics collections. const MetricsEnabledFlag = "metrics" diff --git a/miner/worker.go b/miner/worker.go index 995c401690..a8985a2a81 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -23,6 +23,7 @@ import ( "github.com/tomochain/tomochain/accounts" "github.com/tomochain/tomochain/tomoxlending/lendingstate" + "github.com/tomochain/tomochain/trie" "math/big" "os" @@ -204,6 +205,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { self.current.txs, nil, self.current.receipts, + new(trie.Trie), ), self.current.state.Copy() } return self.current.Block, self.current.state.Copy() @@ -219,6 +221,7 @@ func (self *worker) pendingBlock() *types.Block { self.current.txs, nil, self.current.receipts, + new(trie.Trie), ) } return self.current.Block diff --git a/p2p/discover/node_test.go b/p2p/discover/node_test.go index 8e3da2c2aa..ddf8a7bd98 100644 --- a/p2p/discover/node_test.go +++ b/p2p/discover/node_test.go @@ -142,7 +142,7 @@ var parseNodeTests = []struct { { // This test checks that errors from url.Parse are handled. rawurl: "://foo", - wantError: `parse ://foo: missing protocol scheme`, + wantError: `parse "://foo": missing protocol scheme`, }, } diff --git a/p2p/discv5/node_test.go b/p2p/discv5/node_test.go index a28f298252..d0fa6880a3 100644 --- a/p2p/discv5/node_test.go +++ b/p2p/discv5/node_test.go @@ -141,7 +141,7 @@ var parseNodeTests = []struct { { // This test checks that errors from url.Parse are handled. rawurl: "://foo", - wantError: `parse ://foo: missing protocol scheme`, + wantError: `parse "://foo": missing protocol scheme`, }, } diff --git a/params/config.go b/params/config.go index 056457d9e8..d5e46dc88a 100644 --- a/params/config.go +++ b/params/config.go @@ -102,16 +102,16 @@ var ( // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. - AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil, nil} + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil, nil} // AllPosvProtocolChanges contains every protocol change (EIPs) introduced // and accepted by the Ethereum core developers into the Posv consensus. // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. - AllPosvProtocolChanges = &ChainConfig{big.NewInt(89), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, &PosvConfig{Period: 0, Epoch: 30000}} - AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}, nil} - TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil, nil} + AllPosvProtocolChanges = &ChainConfig{big.NewInt(89), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), nil, big.NewInt(0), nil, nil, nil, &PosvConfig{Period: 0, Epoch: 30000}} + AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), nil, big.NewInt(0), nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}, nil} + TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil, nil} TestRules = TestChainConfig.Rules(new(big.Int)) ) @@ -132,8 +132,9 @@ type ChainConfig struct { EIP150Block *big.Int `json:"eip150Block,omitempty"` // EIP150 HF block (nil = no fork) EIP150Hash common.Hash `json:"eip150Hash,omitempty"` // EIP150 HF hash (needed for header only clients as only gas pricing changed) - EIP155Block *big.Int `json:"eip155Block,omitempty"` // EIP155 HF block - EIP158Block *big.Int `json:"eip158Block,omitempty"` // EIP158 HF block + EIP155Block *big.Int `json:"eip155Block,omitempty"` // EIP155 HF block + EIP158Block *big.Int `json:"eip158Block,omitempty"` // EIP158 HF block + EIP2718Block *big.Int `json:"eip2718Block,omitempty"` // EIP2718 HF block (nil = no fork, 0 = already activated) ByzantiumBlock *big.Int `json:"byzantiumBlock,omitempty"` // Byzantium switch block (nil = no fork, 0 = already on byzantium) ConstantinopleBlock *big.Int `json:"constantinopleBlock,omitempty"` // Constantinople switch block (nil = no fork, 0 = already activated) @@ -189,7 +190,7 @@ func (c *ChainConfig) String() string { default: engine = "unknown" } - return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Engine: %v}", + return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v EIP2718: %v Engine: %v}", c.ChainId, c.HomesteadBlock, c.DAOForkBlock, @@ -199,6 +200,7 @@ func (c *ChainConfig) String() string { c.EIP158Block, c.ByzantiumBlock, c.ConstantinopleBlock, + c.EIP2718Block, engine, ) } @@ -273,6 +275,10 @@ func (c *ChainConfig) IsTIPTomoXCancellationFee(num *big.Int) bool { return isForked(common.TIPTomoXCancellationFee, num) } +func (c *ChainConfig) IsEIP2718(num *big.Int) bool { + return isForked(c.EIP2718Block, num) +} + // GasTable returns the gas table corresponding to the current phase (homestead or homestead reprice). // // The returned GasTable's fields shouldn't, under any circumstances, be changed. @@ -401,7 +407,7 @@ func (err *ConfigCompatError) Error() string { // phases. type Rules struct { ChainId *big.Int - IsHomestead, IsEIP150, IsEIP155, IsEIP158 bool + IsHomestead, IsEIP150, IsEIP155, IsEIP158, IsEIP2718 bool IsByzantium, IsConstantinople, IsPetersburg, IsIstanbul bool } @@ -416,6 +422,7 @@ func (c *ChainConfig) Rules(num *big.Int) Rules { IsEIP150: c.IsEIP150(num), IsEIP155: c.IsEIP155(num), IsEIP158: c.IsEIP158(num), + IsEIP2718: c.IsEIP2718(num), IsByzantium: c.IsByzantium(num), IsConstantinople: c.IsConstantinople(num), IsPetersburg: c.IsPetersburg(num), diff --git a/params/version.go b/params/version.go index af4d16e53c..c220c12734 100644 --- a/params/version.go +++ b/params/version.go @@ -21,10 +21,10 @@ import ( ) const ( - VersionMajor = 2 // Major version component of the current release - VersionMinor = 3 // Minor version component of the current release - VersionPatch = 2 // Patch version component of the current release - VersionMeta = "stable" // Version metadata to append to the version string + VersionMajor = 2 // Major version component of the current release + VersionMinor = 4 // Minor version component of the current release + VersionPatch = 0 // Patch version component of the current release + VersionMeta = "dev" // Version metadata to append to the version string ) // Version holds the textual version string. diff --git a/rlp/decode.go b/rlp/decode.go index 60d9dab2b5..ac93c139a9 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -26,100 +26,77 @@ import ( "math/big" "reflect" "strings" + "sync" + + "github.com/holiman/uint256" + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) +//lint:ignore ST1012 EOL is not an error. + +// EOL is returned when the end of the current list +// has been reached during streaming. +var EOL = errors.New("rlp: end of list") + var ( + ErrExpectedString = errors.New("rlp: expected String or Byte") + ErrExpectedList = errors.New("rlp: expected List") + ErrCanonInt = errors.New("rlp: non-canonical integer format") + ErrCanonSize = errors.New("rlp: non-canonical size information") + ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") + ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") + + // internal errors + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") errNoPointer = errors.New("rlp: interface given to Decode must be a pointer") errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil") + errUint256Large = errors.New("rlp: value too large for uint256") + + streamPool = sync.Pool{ + New: func() interface{} { return new(Stream) }, + } ) -// Decoder is implemented by types that require custom RLP -// decoding rules or need to decode into private fields. +// Decoder is implemented by types that require custom RLP decoding rules or need to decode +// into private fields. // -// The DecodeRLP method should read one value from the given -// Stream. It is not forbidden to read less or more, but it might -// be confusing. +// The DecodeRLP method should read one value from the given Stream. It is not forbidden to +// read less or more, but it might be confusing. type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result in the -// value pointed to by val. Val must be a non-nil pointer. If r does -// not implement ByteReader, Decode will do its own buffering. -// -// Decode uses the following type-dependent decoding rules: -// -// If the type implements the Decoder interface, decode calls -// DecodeRLP. -// -// To decode into a pointer, Decode will decode into the value pointed -// to. If the pointer is nil, a new value of the pointer's element -// type is allocated. If the pointer is non-nil, the existing value -// will be reused. -// -// To decode into a struct, Decode expects the input to be an RLP -// list. The decoded elements of the list are assigned to each public -// field in the order given by the struct's definition. The input list -// must contain an element for each decoded field. Decode returns an -// error if there are too few or too many elements. +// Decode parses RLP-encoded data from r and stores the result in the value pointed to by +// val. Please see package-level documentation for the decoding rules. Val must be a +// non-nil pointer. // -// The decoding of struct fields honours certain struct tags, "tail", -// "nil" and "-". +// If r does not implement ByteReader, Decode will do its own buffering. // -// The "-" tag ignores fields. +// Note that Decode does not set an input limit for all readers and may be vulnerable to +// panics cause by huge value sizes. If you need an input limit, use // -// For an explanation of "tail", see the example. -// -// The "nil" tag applies to pointer-typed fields and changes the decoding -// rules for the field such that input values of size zero decode as a nil -// pointer. This tag can be useful when decoding recursive types. -// -// type StructWithEmptyOK struct { -// Foo *[20]byte `rlp:"nil"` -// } -// -// To decode into a slice, the input must be a list and the resulting -// slice will contain the input elements in order. For byte slices, -// the input must be an RLP string. Array types decode similarly, with -// the additional restriction that the number of input elements (or -// bytes) must match the array's length. -// -// To decode into a Go string, the input must be an RLP string. The -// input bytes are taken as-is and will not necessarily be valid UTF-8. -// -// To decode into an unsigned integer type, the input must also be an RLP -// string. The bytes are interpreted as a big endian representation of -// the integer. If the RLP string is larger than the bit size of the -// type, Decode will return an error. Decode also supports *big.Int. -// There is no size limit for big integers. -// -// To decode into an interface value, Decode stores one of these -// in the value: -// -// []interface{}, for RLP lists -// []byte, for RLP strings -// -// Non-empty interface types are not supported, nor are booleans, -// signed integers, floating point numbers, maps, channels and -// functions. -// -// Note that Decode does not set an input limit for all readers -// and may be vulnerable to panics cause by huge value sizes. If -// you need an input limit, use -// -// NewStream(r, limit).Decode(val) +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { - // TODO: this could use a Stream from a pool. - return NewStream(r, 0).Decode(val) + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, 0) + return stream.Decode(val) } -// DecodeBytes parses RLP data from b into val. -// Please see the documentation of Decode for the decoding rules. -// The input must contain exactly one value and no trailing data. +// DecodeBytes parses RLP data from b into val. Please see package-level documentation for +// the decoding rules. The input must contain exactly one value and no trailing data. func DecodeBytes(b []byte, val interface{}) error { - // TODO: this could use a Stream from a pool. r := bytes.NewReader(b) - if err := NewStream(r, uint64(len(b))).Decode(val); err != nil { + + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, uint64(len(b))) + if err := stream.Decode(val); err != nil { return err } if r.Len() > 0 { @@ -173,21 +150,26 @@ func addErrorContext(err error, ctx string) error { var ( decoderInterface = reflect.TypeOf(new(Decoder)).Elem() bigInt = reflect.TypeOf(big.Int{}) + u256Int = reflect.TypeOf(uint256.Int{}) ) -func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { +func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) { kind := typ.Kind() switch { case typ == rawValueType: return decodeRawValue, nil - case typ.Implements(decoderInterface): - return decodeDecoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): - return decodeDecoderNoPtr, nil case typ.AssignableTo(reflect.PtrTo(bigInt)): return decodeBigInt, nil case typ.AssignableTo(bigInt): return decodeBigIntNoPtr, nil + case typ == reflect.PtrTo(u256Int): + return decodeU256, nil + case typ == u256Int: + return decodeU256NoPtr, nil + case kind == reflect.Ptr: + return makePtrDecoder(typ, tags) + case reflect.PtrTo(typ).Implements(decoderInterface): + return decodeDecoder, nil case isUint(kind): return decodeUint, nil case kind == reflect.Bool: @@ -198,11 +180,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { return makeListDecoder(typ, tags) case kind == reflect.Struct: return makeStructDecoder(typ) - case kind == reflect.Ptr: - if tags.nilOK { - return makeOptionalPtrDecoder(typ) - } - return makePtrDecoder(typ) case kind == reflect.Interface: return decodeInterface, nil default: @@ -252,35 +229,48 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { } func decodeBigInt(s *Stream, val reflect.Value) error { - b, err := s.Bytes() + i := val.Interface().(*big.Int) + if i == nil { + i = new(big.Int) + val.Set(reflect.ValueOf(i)) + } + + err := s.decodeBigInt(i) if err != nil { return wrapStreamError(err, val.Type()) } - i := val.Interface().(*big.Int) + return nil +} + +func decodeU256NoPtr(s *Stream, val reflect.Value) error { + return decodeU256(s, val.Addr()) +} + +func decodeU256(s *Stream, val reflect.Value) error { + i := val.Interface().(*uint256.Int) if i == nil { - i = new(big.Int) + i = new(uint256.Int) val.Set(reflect.ValueOf(i)) } - // Reject leading zero bytes - if len(b) > 0 && b[0] == 0 { - return wrapStreamError(ErrCanonInt, val.Type()) + + err := s.ReadUint256(i) + if err != nil { + return wrapStreamError(err, val.Type()) } - i.SetBytes(b) return nil } -func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { +func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { if typ.Kind() == reflect.Array { return decodeByteArray, nil - } else { - return decodeByteSlice, nil } + return decodeByteSlice, nil } - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } var dec decoder switch { @@ -288,7 +278,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { dec = func(s *Stream, val reflect.Value) error { return decodeListArray(s, val, etypeinfo.decoder) } - case tag.tail: + case tag.Tail: // A slice with "tail" tag can occur as the last field // of a struct and is supposed to swallow all remaining // list elements. The struct decoder already called s.List, @@ -381,25 +371,23 @@ func decodeByteArray(s *Stream, val reflect.Value) error { if err != nil { return err } - vlen := val.Len() + slice := byteArrayBytes(val, val.Len()) switch kind { case Byte: - if vlen == 0 { + if len(slice) == 0 { return &decodeError{msg: "input string too long", typ: val.Type()} - } - if vlen > 1 { + } else if len(slice) > 1 { return &decodeError{msg: "input string too short", typ: val.Type()} } - bv, _ := s.Uint() - val.Index(0).SetUint(bv) + slice[0] = s.byteval + s.kind = -1 case String: - if uint64(vlen) < size { + if uint64(len(slice)) < size { return &decodeError{msg: "input string too long", typ: val.Type()} } - if uint64(vlen) > size { + if uint64(len(slice)) > size { return &decodeError{msg: "input string too short", typ: val.Type()} } - slice := val.Slice(0, vlen).Interface().([]byte) if err := s.readFull(slice); err != nil { return err } @@ -418,13 +406,25 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { if err != nil { return nil, err } + for _, f := range fields { + if f.info.decoderErr != nil { + return nil, structFieldError{typ, f.index, f.info.decoderErr} + } + } dec := func(s *Stream, val reflect.Value) (err error) { if _, err := s.List(); err != nil { return wrapStreamError(err, typ) } - for _, f := range fields { + for i, f := range fields { err := f.info.decoder(s, val.Field(f.index)) if err == EOL { + if f.optional { + // The field is optional, so reaching the end of the list before + // reaching the last field is acceptable. All remaining undecoded + // fields are zeroed. + zeroFields(val, fields[i:]) + break + } return &decodeError{msg: "too few elements", typ: typ} } else if err != nil { return addErrorContext(err, "."+typ.Field(f.index).Name) @@ -435,15 +435,29 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } -// makePtrDecoder creates a decoder that decodes into -// the pointer's element type. -func makePtrDecoder(typ reflect.Type) (decoder, error) { +func zeroFields(structval reflect.Value, fields []field) { + for _, f := range fields { + fv := structval.Field(f.index) + fv.Set(reflect.Zero(fv.Type())) + } +} + +// makePtrDecoder creates a decoder that decodes into the pointer's element type. +func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + switch { + case etypeinfo.decoderErr != nil: + return nil, etypeinfo.decoderErr + case !tag.NilOK: + return makeSimplePtrDecoder(etype, etypeinfo), nil + default: + return makeNilPtrDecoder(etype, etypeinfo, tag), nil } - dec := func(s *Stream, val reflect.Value) (err error) { +} + +func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder { + return func(s *Stream, val reflect.Value) (err error) { newval := val if val.IsNil() { newval = reflect.New(etype) @@ -453,30 +467,39 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } -// makeOptionalPtrDecoder creates a decoder that decodes empty values -// as nil. Non-empty values are decoded into a value of the element type, -// just like makePtrDecoder does. +// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty +// values are decoded into a value of the element type, just like makePtrDecoder does. // // This decoder is used for pointer-typed struct fields with struct tag "nil". -func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { - etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err - } - dec := func(s *Stream, val reflect.Value) (err error) { +func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder { + typ := reflect.PtrTo(etype) + nilPtr := reflect.Zero(typ) + + // Determine the value kind that results in nil pointer. + nilKind := typeNilKind(etype, ts) + + return func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() - if err != nil || size == 0 && kind != Byte { + if err != nil { + val.Set(nilPtr) + return wrapStreamError(err, typ) + } + // Handle empty values as a nil pointer. + if kind != Byte && size == 0 { + if kind != nilKind { + return &decodeError{ + msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind), + typ: typ, + } + } // rearm s.Kind. This is important because the input // position must advance to the next value even though // we don't read anything. s.kind = -1 - // set the pointer to nil. - val.Set(reflect.Zero(typ)) - return err + val.Set(nilPtr) + return nil } newval := val if val.IsNil() { @@ -487,7 +510,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } var ifsliceType = reflect.TypeOf([]interface{}{}) @@ -516,25 +538,12 @@ func decodeInterface(s *Stream, val reflect.Value) error { return nil } -// This decoder is used for non-pointer values of types -// that implement the Decoder interface using a pointer receiver. -func decodeDecoderNoPtr(s *Stream, val reflect.Value) error { - return val.Addr().Interface().(Decoder).DecodeRLP(s) -} - func decodeDecoder(s *Stream, val reflect.Value) error { - // Decoder instances are not handled using the pointer rule if the type - // implements Decoder with pointer receiver (i.e. always) - // because it might handle empty values specially. - // We need to allocate one here in this case, like makePtrDecoder does. - if val.Kind() == reflect.Ptr && val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - return val.Interface().(Decoder).DecodeRLP(s) + return val.Addr().Interface().(Decoder).DecodeRLP(s) } // Kind represents the kind of value contained in an RLP stream. -type Kind int +type Kind int8 const ( Byte Kind = iota @@ -555,29 +564,6 @@ func (k Kind) String() string { } } -var ( - // EOL is returned when the end of the current list - // has been reached during streaming. - EOL = errors.New("rlp: end of list") - - // Actual Errors - ErrExpectedString = errors.New("rlp: expected String or Byte") - ErrExpectedList = errors.New("rlp: expected List") - ErrCanonInt = errors.New("rlp: non-canonical integer format") - ErrCanonSize = errors.New("rlp: non-canonical size information") - ErrElemTooLarge = errors.New("rlp: element is larger than containing list") - ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") - - // This error is reported by DecodeBytes if the slice contains - // additional data after the first RLP value. - ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") - - // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") - errUintOverflow = errors.New("rlp: uint overflow") -) - // ByteReader must be implemented by any input reader for a Stream. It // is implemented by e.g. bufio.Reader and bytes.Reader. type ByteReader interface { @@ -600,22 +586,16 @@ type ByteReader interface { type Stream struct { r ByteReader - // number of bytes remaining to be read from r. - remaining uint64 - limited bool - - // auxiliary buffer for integer decoding - uintbuf []byte - - kind Kind // kind of value ahead - size uint64 // size of value ahead - byteval byte // value of single byte in type tag - kinderr error // error from last readKind - stack []listpos + remaining uint64 // number of bytes remaining to be read from r + size uint64 // size of value ahead + kinderr error // error from last readKind + stack []uint64 // list sizes + uintbuf [32]byte // auxiliary buffer for integer decoding + kind Kind // kind of value ahead + byteval byte // value of single byte in type tag + limited bool // true if input limit is in effect } -type listpos struct{ pos, size uint64 } - // NewStream creates a new decoding stream reading from r. // // If r implements the ByteReader interface, Stream will @@ -675,6 +655,37 @@ func (s *Stream) Bytes() ([]byte, error) { } } +// ReadBytes decodes the next RLP value and stores the result in b. +// The value size must match len(b) exactly. +func (s *Stream) ReadBytes(b []byte) error { + kind, size, err := s.Kind() + if err != nil { + return err + } + switch kind { + case Byte: + if len(b) != 1 { + return fmt.Errorf("input value has wrong size 1, want %d", len(b)) + } + b[0] = s.byteval + s.kind = -1 // rearm Kind + return nil + case String: + if uint64(len(b)) != size { + return fmt.Errorf("input value has wrong size %d, want %d", size, len(b)) + } + if err = s.readFull(b); err != nil { + return err + } + if size == 1 && b[0] < 128 { + return ErrCanonSize + } + return nil + default: + return ErrExpectedString + } +} + // Raw reads a raw encoded value including RLP type information. func (s *Stream) Raw() ([]byte, error) { kind, size, err := s.Kind() @@ -685,8 +696,8 @@ func (s *Stream) Raw() ([]byte, error) { s.kind = -1 // rearm Kind return []byte{s.byteval}, nil } - // the original header has already been read and is no longer - // available. read content and put a new header in front of it. + // The original header has already been read and is no longer + // available. Read content and put a new header in front of it. start := headsize(size) buf := make([]byte, uint64(start)+size) if err := s.readFull(buf[start:]); err != nil { @@ -703,10 +714,31 @@ func (s *Stream) Raw() ([]byte, error) { // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. +// +// Deprecated: use s.Uint64 instead. func (s *Stream) Uint() (uint64, error) { return s.uint(64) } +func (s *Stream) Uint64() (uint64, error) { + return s.uint(64) +} + +func (s *Stream) Uint32() (uint32, error) { + i, err := s.uint(32) + return uint32(i), err +} + +func (s *Stream) Uint16() (uint16, error) { + i, err := s.uint(16) + return uint16(i), err +} + +func (s *Stream) Uint8() (uint8, error) { + i, err := s.uint(8) + return uint8(i), err +} + func (s *Stream) uint(maxbits int) (uint64, error) { kind, size, err := s.Kind() if err != nil { @@ -769,7 +801,14 @@ func (s *Stream) List() (size uint64, err error) { if kind != List { return 0, ErrExpectedList } - s.stack = append(s.stack, listpos{0, size}) + + // Remove size of inner list from outer list before pushing the new size + // onto the stack. This ensures that the remaining outer list size will + // be correct after the matching call to ListEnd. + if inList, limit := s.listLimit(); inList { + s.stack[len(s.stack)-1] = limit - size + } + s.stack = append(s.stack, size) s.kind = -1 s.size = 0 return size, nil @@ -778,22 +817,116 @@ func (s *Stream) List() (size uint64, err error) { // ListEnd returns to the enclosing list. // The input reader must be positioned at the end of a list. func (s *Stream) ListEnd() error { - if len(s.stack) == 0 { + // Ensure that no more data is remaining in the current list. + if inList, listLimit := s.listLimit(); !inList { return errNotInList - } - tos := s.stack[len(s.stack)-1] - if tos.pos != tos.size { + } else if listLimit > 0 { return errNotAtEOL } s.stack = s.stack[:len(s.stack)-1] // pop - if len(s.stack) > 0 { - s.stack[len(s.stack)-1].pos += tos.size - } s.kind = -1 s.size = 0 return nil } +// MoreDataInList reports whether the current list context contains +// more data to be read. +func (s *Stream) MoreDataInList() bool { + _, listLimit := s.listLimit() + return listLimit > 0 +} + +// BigInt decodes an arbitrary-size integer value. +func (s *Stream) BigInt() (*big.Int, error) { + i := new(big.Int) + if err := s.decodeBigInt(i); err != nil { + return nil, err + } + return i, nil +} + +func (s *Stream) decodeBigInt(dst *big.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // For integers smaller than s.uintbuf, allocating a buffer + // can be avoided. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + // For large integers, a temporary buffer is needed. + buffer = make([]byte, size) + if err := s.readFull(buffer); err != nil { + return err + } + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + +// ReadUint256 decodes the next value as a uint256. +func (s *Stream) ReadUint256(dst *uint256.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // All possible uint256 values fit into s.uintbuf. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + return errUint256Large + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + // Decode decodes a value and stores the result in the value pointed // to by val. Please see the documentation for the Decode function // to learn about the decoding rules. @@ -809,14 +942,14 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem(), tags{}) + decoder, err := cachedDecoder(rtyp.Elem()) if err != nil { return err } - err = info.decoder(s, rval.Elem()) + err = decoder(s, rval.Elem()) if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { - // add decode target type to error so context has more meaning + // Add decode target type to error so context has more meaning. decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) } return err @@ -839,6 +972,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { case *bytes.Reader: s.remaining = uint64(br.Len()) s.limited = true + case *bytes.Buffer: + s.remaining = uint64(br.Len()) + s.limited = true case *strings.Reader: s.remaining = uint64(br.Len()) s.limited = true @@ -857,9 +993,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { s.size = 0 s.kind = -1 s.kinderr = nil - if s.uintbuf == nil { - s.uintbuf = make([]byte, 8) - } + s.byteval = 0 + s.uintbuf = [32]byte{} } // Kind returns the kind and size of the next value in the @@ -874,35 +1009,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { // the value. Subsequent calls to Kind (until the value is decoded) // will not advance the input reader and return cached information. func (s *Stream) Kind() (kind Kind, size uint64, err error) { - var tos *listpos - if len(s.stack) > 0 { - tos = &s.stack[len(s.stack)-1] - } - if s.kind < 0 { - s.kinderr = nil - // Don't read further if we're at the end of the - // innermost list. - if tos != nil && tos.pos == tos.size { - return 0, 0, EOL - } - s.kind, s.size, s.kinderr = s.readKind() - if s.kinderr == nil { - if tos == nil { - // At toplevel, check that the value is smaller - // than the remaining input length. - if s.limited && s.size > s.remaining { - s.kinderr = ErrValueTooLarge - } - } else { - // Inside a list, check that the value doesn't overflow the list. - if s.size > tos.size-tos.pos { - s.kinderr = ErrElemTooLarge - } - } + if s.kind >= 0 { + return s.kind, s.size, s.kinderr + } + + // Check for end of list. This needs to be done here because readKind + // checks against the list size, and would return the wrong error. + inList, listLimit := s.listLimit() + if inList && listLimit == 0 { + return 0, 0, EOL + } + // Read the actual size tag. + s.kind, s.size, s.kinderr = s.readKind() + if s.kinderr == nil { + // Check the data size of the value ahead against input limits. This + // is done here because many decoders require allocating an input + // buffer matching the value size. Checking it here protects those + // decoders from inputs declaring very large value size. + if inList && s.size > listLimit { + s.kinderr = ErrElemTooLarge + } else if s.limited && s.size > s.remaining { + s.kinderr = ErrValueTooLarge } } - // Note: this might return a sticky error generated - // by an earlier call to readKind. return s.kind, s.size, s.kinderr } @@ -929,37 +1058,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) { s.byteval = b return Byte, 0, nil case b < 0xB8: - // Otherwise, if a string is 0-55 bytes long, - // the RLP encoding consists of a single byte with value 0x80 plus the - // length of the string followed by the string. The range of the first - // byte is thus [0x80, 0xB7]. + // Otherwise, if a string is 0-55 bytes long, the RLP encoding consists + // of a single byte with value 0x80 plus the length of the string + // followed by the string. The range of the first byte is thus [0x80, 0xB7]. return String, uint64(b - 0x80), nil case b < 0xC0: - // If a string is more than 55 bytes long, the - // RLP encoding consists of a single byte with value 0xB7 plus the length - // of the length of the string in binary form, followed by the length of - // the string, followed by the string. For example, a length-1024 string - // would be encoded as 0xB90400 followed by the string. The range of - // the first byte is thus [0xB8, 0xBF]. + // If a string is more than 55 bytes long, the RLP encoding consists of a + // single byte with value 0xB7 plus the length of the length of the + // string in binary form, followed by the length of the string, followed + // by the string. For example, a length-1024 string would be encoded as + // 0xB90400 followed by the string. The range of the first byte is thus + // [0xB8, 0xBF]. size, err = s.readUint(b - 0xB7) if err == nil && size < 56 { err = ErrCanonSize } return String, size, err case b < 0xF8: - // If the total payload of a list - // (i.e. the combined length of all its items) is 0-55 bytes long, the - // RLP encoding consists of a single byte with value 0xC0 plus the length - // of the list followed by the concatenation of the RLP encodings of the - // items. The range of the first byte is thus [0xC0, 0xF7]. + // If the total payload of a list (i.e. the combined length of all its + // items) is 0-55 bytes long, the RLP encoding consists of a single byte + // with value 0xC0 plus the length of the list followed by the + // concatenation of the RLP encodings of the items. The range of the + // first byte is thus [0xC0, 0xF7]. return List, uint64(b - 0xC0), nil default: - // If the total payload of a list is more than 55 bytes long, - // the RLP encoding consists of a single byte with value 0xF7 - // plus the length of the length of the payload in binary - // form, followed by the length of the payload, followed by - // the concatenation of the RLP encodings of the items. The - // range of the first byte is thus [0xF8, 0xFF]. + // If the total payload of a list is more than 55 bytes long, the RLP + // encoding consists of a single byte with value 0xF7 plus the length of + // the length of the payload in binary form, followed by the length of + // the payload, followed by the concatenation of the RLP encodings of + // the items. The range of the first byte is thus [0xF8, 0xFF]. size, err = s.readUint(b - 0xF7) if err == nil && size < 56 { err = ErrCanonSize @@ -977,23 +1104,24 @@ func (s *Stream) readUint(size byte) (uint64, error) { b, err := s.readByte() return uint64(b), err default: - start := int(8 - size) - for i := 0; i < start; i++ { - s.uintbuf[i] = 0 + buffer := s.uintbuf[:8] + for i := range buffer { + buffer[i] = 0 } - if err := s.readFull(s.uintbuf[start:]); err != nil { + start := int(8 - size) + if err := s.readFull(buffer[start:]); err != nil { return 0, err } - if s.uintbuf[start] == 0 { - // Note: readUint is also used to decode integer - // values. The error needs to be adjusted to become - // ErrCanonInt in this case. + if buffer[start] == 0 { + // Note: readUint is also used to decode integer values. + // The error needs to be adjusted to become ErrCanonInt in this case. return 0, ErrCanonSize } - return binary.BigEndian.Uint64(s.uintbuf), nil + return binary.BigEndian.Uint64(buffer[:]), nil } } +// readFull reads into buf from the underlying stream. func (s *Stream) readFull(buf []byte) (err error) { if err := s.willRead(uint64(len(buf))); err != nil { return err @@ -1004,11 +1132,18 @@ func (s *Stream) readFull(buf []byte) (err error) { n += nn } if err == io.EOF { - err = io.ErrUnexpectedEOF + if n < len(buf) { + err = io.ErrUnexpectedEOF + } else { + // Readers are allowed to give EOF even though the read succeeded. + // In such cases, we discard the EOF, like io.ReadFull() does. + err = nil + } } return err } +// readByte reads a single byte from the underlying stream. func (s *Stream) readByte() (byte, error) { if err := s.willRead(1); err != nil { return 0, err @@ -1020,16 +1155,16 @@ func (s *Stream) readByte() (byte, error) { return b, err } +// willRead is called before any read from the underlying stream. It checks +// n against size limits, and updates the limits if n doesn't overflow them. func (s *Stream) willRead(n uint64) error { s.kind = -1 // rearm Kind - if len(s.stack) > 0 { - // check list overflow - tos := s.stack[len(s.stack)-1] - if n > tos.size-tos.pos { + if inList, limit := s.listLimit(); inList { + if n > limit { return ErrElemTooLarge } - s.stack[len(s.stack)-1].pos += n + s.stack[len(s.stack)-1] = limit - n } if s.limited { if n > s.remaining { @@ -1039,3 +1174,11 @@ func (s *Stream) willRead(n uint64) error { } return nil } + +// listLimit returns the amount of data remaining in the innermost list. +func (s *Stream) listLimit() (inList bool, limit uint64) { + if len(s.stack) == 0 { + return false, 0 + } + return true, s.stack[len(s.stack)-1] +} diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd0012..3ee237fb09 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -26,6 +26,10 @@ import ( "reflect" "strings" "testing" + + "github.com/tomochain/tomochain/common/math" + + "github.com/holiman/uint256" ) func TestStreamKind(t *testing.T) { @@ -284,6 +288,47 @@ func TestStreamRaw(t *testing.T) { } } +func TestStreamReadBytes(t *testing.T) { + tests := []struct { + input string + size int + err string + }{ + // kind List + {input: "C0", size: 1, err: "rlp: expected String or Byte"}, + // kind Byte + {input: "04", size: 0, err: "input value has wrong size 1, want 0"}, + {input: "04", size: 1}, + {input: "04", size: 2, err: "input value has wrong size 1, want 2"}, + // kind String + {input: "820102", size: 0, err: "input value has wrong size 2, want 0"}, + {input: "820102", size: 1, err: "input value has wrong size 2, want 1"}, + {input: "820102", size: 2}, + {input: "820102", size: 3, err: "input value has wrong size 2, want 3"}, + } + + for _, test := range tests { + test := test + name := fmt.Sprintf("input_%s/size_%d", test.input, test.size) + t.Run(name, func(t *testing.T) { + s := NewStream(bytes.NewReader(unhex(test.input)), 0) + b := make([]byte, test.size) + err := s.ReadBytes(b) + if test.err == "" { + if err != nil { + t.Errorf("unexpected error %q", err) + } + } else { + if err == nil { + t.Errorf("expected error, got nil") + } else if err.Error() != test.err { + t.Errorf("wrong error %q", err) + } + } + }) + } +} + func TestDecodeErrors(t *testing.T) { r := bytes.NewReader(nil) @@ -327,6 +372,15 @@ type recstruct struct { Child *recstruct `rlp:"nil"` } +type bigIntStruct struct { + I *big.Int + B string +} + +type invalidNilTag struct { + X []byte `rlp:"nil"` +} + type invalidTail1 struct { A uint `rlp:"tail"` B string @@ -347,19 +401,79 @@ type tailUint struct { Tail []uint `rlp:"tail"` } -var ( - veryBigInt = big.NewInt(0).Add( - big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), - big.NewInt(0xFFFF), - ) -) +type tailPrivateFields struct { + A uint + Tail []uint `rlp:"tail"` + x, y bool //lint:ignore U1000 unused fields required for testing purposes. +} + +type nilListUint struct { + X *uint `rlp:"nilList"` +} + +type nilStringSlice struct { + X *[]uint `rlp:"nilString"` +} + +type intField struct { + X int +} + +type optionalFields struct { + A uint + B uint `rlp:"optional"` + C uint `rlp:"optional"` +} + +type optionalAndTailField struct { + A uint + B uint `rlp:"optional"` + Tail []uint `rlp:"tail"` +} + +type optionalBigIntField struct { + A uint + B *big.Int `rlp:"optional"` +} + +type optionalPtrField struct { + A uint + B *[3]byte `rlp:"optional"` +} + +type nonOptionalPtrField struct { + A uint + B *[3]byte +} -type hasIgnoredField struct { +type multipleOptionalFields struct { + A *[3]byte `rlp:"optional"` + B *[3]byte `rlp:"optional"` +} + +type optionalPtrFieldNil struct { + A uint + B *[3]byte `rlp:"optional,nil"` +} + +type ignoredField struct { A uint B uint `rlp:"-"` C uint } +var ( + veryBigInt = new(big.Int).Add( + new(big.Int).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), + big.NewInt(0xFFFF), + ) + veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil) +) + +var ( + veryBigInt256, _ = uint256.FromBig(veryBigInt) +) + var decodeTests = []decodeTest{ // booleans {input: "01", ptr: new(bool), value: true}, @@ -428,12 +542,31 @@ var decodeTests = []decodeTest{ {input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"}, // big ints + {input: "80", ptr: new(*big.Int), value: big.NewInt(0)}, {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, + {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works + + // big int errors {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, - {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, - {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"}, + {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"}, + + // uint256 + {input: "80", ptr: new(*uint256.Int), value: uint256.NewInt(0)}, + {input: "01", ptr: new(*uint256.Int), value: uint256.NewInt(1)}, + {input: "88FFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: uint256.NewInt(math.MaxUint64)}, + {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: veryBigInt256}, + {input: "10", ptr: new(uint256.Int), value: *uint256.NewInt(16)}, // non-pointer also works + + // uint256 errors + {input: "C0", ptr: new(*uint256.Int), error: "rlp: expected input string or byte for *uint256.Int"}, + {input: "00", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "820001", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "8105", ptr: new(*uint256.Int), error: "rlp: non-canonical size information for *uint256.Int"}, + {input: "A1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00", ptr: new(*uint256.Int), error: "rlp: value too large for uint256"}, // structs { @@ -446,6 +579,13 @@ var decodeTests = []decodeTest{ ptr: new(recstruct), value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, }, + { + // This checks that empty big.Int works correctly in struct context. It's easy to + // miss the update of s.kind for this case, so it needs its own test. + input: "C58083343434", + ptr: new(bigIntStruct), + value: bigIntStruct{new(big.Int), "444"}, + }, // struct errors { @@ -479,20 +619,20 @@ var decodeTests = []decodeTest{ error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I", }, { - input: "C0", - ptr: new(invalidTail1), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)", - }, - { - input: "C0", - ptr: new(invalidTail2), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)", + input: "C103", + ptr: new(intField), + error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)", }, { input: "C50102C20102", ptr: new(tailUint), error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]", }, + { + input: "C0", + ptr: new(invalidNilTag), + error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`, + }, // struct tag "tail" { @@ -510,12 +650,192 @@ var decodeTests = []decodeTest{ ptr: new(tailRaw), value: tailRaw{A: 1, Tail: []RawValue{}}, }, + { + input: "C3010203", + ptr: new(tailPrivateFields), + value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, + }, + { + input: "C0", + ptr: new(invalidTail1), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`, + }, + { + input: "C0", + ptr: new(invalidTail2), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`, + }, // struct tag "-" { input: "C20102", - ptr: new(hasIgnoredField), - value: hasIgnoredField{A: 1, C: 2}, + ptr: new(ignoredField), + value: ignoredField{A: 1, C: 2}, + }, + + // struct tag "nilList" + { + input: "C180", + ptr: new(nilListUint), + error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X", + }, + { + input: "C1C0", + ptr: new(nilListUint), + value: nilListUint{}, + }, + { + input: "C103", + ptr: new(nilListUint), + value: func() interface{} { + v := uint(3) + return nilListUint{X: &v} + }(), + }, + + // struct tag "nilString" + { + input: "C1C0", + ptr: new(nilStringSlice), + error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X", + }, + { + input: "C180", + ptr: new(nilStringSlice), + value: nilStringSlice{}, + }, + { + input: "C2C103", + ptr: new(nilStringSlice), + value: nilStringSlice{X: &[]uint{3}}, + }, + + // struct tag "optional" + { + input: "C101", + ptr: new(optionalFields), + value: optionalFields{1, 0, 0}, + }, + { + input: "C20102", + ptr: new(optionalFields), + value: optionalFields{1, 2, 0}, + }, + { + input: "C3010203", + ptr: new(optionalFields), + value: optionalFields{1, 2, 3}, + }, + { + input: "C401020304", + ptr: new(optionalFields), + error: "rlp: input list has too many elements for rlp.optionalFields", + }, + { + input: "C101", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C401020304", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}}, + }, + { + input: "C101", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: nil}, + }, + { + input: "C20102", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: big.NewInt(2)}, + }, + { + input: "C101", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1}, + }, + { + input: "C20180", // not accepted because "optional" doesn't enable "nil" + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C20102", + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C50183010203", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}}, + }, + { + // all optional fields nil + input: "C0", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: nil, B: nil}, + }, + { + // all optional fields set + input: "C88301020383010203", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, + }, + { + // nil optional field appears before a non-nil one + input: "C58083010203", + ptr: new(multipleOptionalFields), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.multipleOptionalFields).A", + }, + { + // decode a nil ptr into a ptr that is not nil or not optional + input: "C20180", + ptr: new(nonOptionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.nonOptionalPtrField).B", + }, + { + input: "C101", + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20180", // accepted because "nil" tag allows empty input + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalPtrFieldNil), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B", + }, + + // struct tag "optional" field clearing + { + input: "C101", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 0, C: 0}, + }, + { + input: "C20102", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 2, C: 0}, + }, + { + input: "C20102", + ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}}, + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C101", + ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}}, + value: optionalPtrField{A: 1}, }, // RawValue @@ -591,6 +911,26 @@ func TestDecodeWithByteReader(t *testing.T) { }) } +func testDecodeWithEncReader(t *testing.T, n int) { + s := strings.Repeat("0", n) + _, r, _ := EncodeToReader(s) + var decoded string + err := Decode(r, &decoded) + if err != nil { + t.Errorf("Unexpected decode error with n=%v: %v", n, err) + } + if decoded != s { + t.Errorf("Decode mismatch with n=%v", n) + } +} + +// This is a regression test checking that decoding from encReader +// works for RLP values of size 8192 bytes or more. +func TestDecodeWithEncReader(t *testing.T) { + testDecodeWithEncReader(t, 8188) // length with header is 8191 + testDecodeWithEncReader(t, 8189) // length with header is 8192 +} + // plainReader reads from a byte slice but does not // implement ReadByte. It is also not recognized by the // size validation. This is useful to test how the decoder @@ -661,6 +1001,22 @@ func TestDecodeDecoder(t *testing.T) { } } +func TestDecodeDecoderNilPointer(t *testing.T) { + var s struct { + T1 *testDecoder `rlp:"nil"` + T2 *testDecoder + } + if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil { + t.Fatalf("Decode error: %v", err) + } + if s.T1 != nil { + t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called) + } + if s.T2 == nil || !s.T2.called { + t.Errorf("decoder T2 not allocated/called") + } +} + type byteDecoder byte func (bd *byteDecoder) DecodeRLP(s *Stream) error { @@ -691,13 +1047,66 @@ func TestDecoderInByteSlice(t *testing.T) { } } +type unencodableDecoder func() + +func (f *unencodableDecoder) DecodeRLP(s *Stream) error { + if _, err := s.List(); err != nil { + return err + } + if err := s.ListEnd(); err != nil { + return err + } + *f = func() {} + return nil +} + +func TestDecoderFunc(t *testing.T) { + var x func() + if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil { + t.Fatal(err) + } + x() +} + +// This tests the validity checks for fields with struct tag "optional". +func TestInvalidOptionalField(t *testing.T) { + type ( + invalid1 struct { + A uint `rlp:"optional"` + B uint + } + invalid2 struct { + T []uint `rlp:"tail,optional"` + } + invalid3 struct { + T []uint `rlp:"optional,tail"` + } + ) + + tests := []struct { + v interface{} + err string + }{ + {v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`}, + {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`}, + {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`}, + } + for _, test := range tests { + err := DecodeBytes(unhex("C20102"), test.v) + if err == nil { + t.Errorf("no error for %T", test.v) + } else if err.Error() != test.err { + t.Errorf("wrong error for %T: %v", test.v, err.Error()) + } + } +} + func ExampleDecode() { input, _ := hex.DecodeString("C90A1486666F6F626172") type example struct { - A, B uint - private uint // private fields are ignored - String string + A, B uint + String string } var s example @@ -708,7 +1117,7 @@ func ExampleDecode() { fmt.Printf("Decoded value: %#v\n", s) } // Output: - // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"} + // Decoded value: rlp.example{A:0xa, B:0x14, String:"foobar"} } func ExampleDecode_structTagNil() { @@ -768,7 +1177,7 @@ func ExampleStream() { // [102 111 111 98 97 114] } -func BenchmarkDecode(b *testing.B) { +func BenchmarkDecodeUints(b *testing.B) { enc := encodeTestSlice(90000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -783,7 +1192,7 @@ func BenchmarkDecode(b *testing.B) { } } -func BenchmarkDecodeIntSliceReuse(b *testing.B) { +func BenchmarkDecodeUintsReused(b *testing.B) { enc := encodeTestSlice(100000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -798,6 +1207,65 @@ func BenchmarkDecodeIntSliceReuse(b *testing.B) { } } +func BenchmarkDecodeByteArrayStruct(b *testing.B) { + enc, err := EncodeToBytes(&byteArrayStruct{}) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out byteArrayStruct + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*big.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*uint256.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + func encodeTestSlice(n uint) []byte { s := make([]uint, n) for i := uint(0); i < n; i++ { @@ -811,7 +1279,7 @@ func encodeTestSlice(n uint) []byte { } func unhex(str string) []byte { - b, err := hex.DecodeString(strings.Replace(str, " ", "", -1)) + b, err := hex.DecodeString(strings.ReplaceAll(str, " ", "")) if err != nil { panic(fmt.Sprintf("invalid hex string: %q", str)) } diff --git a/rlp/doc.go b/rlp/doc.go index b3a81fe232..eeeee9a43a 100644 --- a/rlp/doc.go +++ b/rlp/doc.go @@ -17,17 +17,142 @@ /* Package rlp implements the RLP serialization format. -The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily -nested arrays of binary data, and RLP is the main encoding method used -to serialize objects in Ethereum. The only purpose of RLP is to encode -structure; encoding specific atomic data types (eg. strings, ints, -floats) is left up to higher-order protocols; in Ethereum integers -must be represented in big endian binary form with no leading zeroes -(thus making the integer value zero equivalent to the empty byte -array). - -RLP values are distinguished by a type tag. The type tag precedes the -value in the input stream and defines the size and kind of the bytes -that follow. +The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily nested arrays of +binary data, and RLP is the main encoding method used to serialize objects in Ethereum. +The only purpose of RLP is to encode structure; encoding specific atomic data types (eg. +strings, ints, floats) is left up to higher-order protocols. In Ethereum integers must be +represented in big endian binary form with no leading zeroes (thus making the integer +value zero equivalent to the empty string). + +RLP values are distinguished by a type tag. The type tag precedes the value in the input +stream and defines the size and kind of the bytes that follow. + +# Encoding Rules + +Package rlp uses reflection and encodes RLP based on the Go type of the value. + +If the type implements the Encoder interface, Encode calls EncodeRLP. It does not +call EncodeRLP on nil pointer values. + +To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct +type, slice or array always encodes as an empty RLP list unless the slice or array has +element type byte. A nil pointer to any other value encodes as the empty string. + +Struct values are encoded as an RLP list of all their encoded public fields. Recursive +struct types are supported. + +To encode slices and arrays, the elements are encoded as an RLP list of the value's +elements. Note that arrays and slices with element type uint8 or byte are always encoded +as an RLP string. + +A Go string is encoded as an RLP string. + +An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP +string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...) +are not supported and will return an error when encoding. + +Boolean values are encoded as the unsigned integers zero (false) and one (true). + +An interface value encodes as the value contained in the interface. + +Floating point numbers, maps, channels and functions are not supported. + +# Decoding Rules + +Decoding uses the following type-dependent rules: + +If the type implements the Decoder interface, DecodeRLP is called. + +To decode into a pointer, the value will be decoded as the element type of the pointer. If +the pointer is nil, a new value of the pointer's element type is allocated. If the pointer +is non-nil, the existing value will be reused. Note that package rlp never leaves a +pointer-type struct field as nil unless one of the "nil" struct tags is present. + +To decode into a struct, decoding expects the input to be an RLP list. The decoded +elements of the list are assigned to each public field in the order given by the struct's +definition. The input list must contain an element for each decoded field. Decoding +returns an error if there are too few or too many elements for the struct. + +To decode into a slice, the input must be a list and the resulting slice will contain the +input elements in order. For byte slices, the input must be an RLP string. Array types +decode similarly, with the additional restriction that the number of input elements (or +bytes) must match the array's defined length. + +To decode into a Go string, the input must be an RLP string. The input bytes are taken +as-is and will not necessarily be valid UTF-8. + +To decode into an unsigned integer type, the input must also be an RLP string. The bytes +are interpreted as a big endian representation of the integer. If the RLP string is larger +than the bit size of the type, decoding will return an error. Decode also supports +*big.Int. There is no size limit for big integers. + +To decode into a boolean, the input must contain an unsigned integer of value zero (false) +or one (true). + +To decode into an interface value, one of these types is stored in the value: + + []interface{}, for RLP lists + []byte, for RLP strings + +Non-empty interface types are not supported when decoding. +Signed integers, floating point numbers, maps, channels and functions cannot be decoded into. + +# Struct Tags + +As with other encoding packages, the "-" tag ignores fields. + + type StructWithIgnoredField struct{ + Ignored uint `rlp:"-"` + Field uint + } + +Go struct values encode/decode as RLP lists. There are two ways of influencing the mapping +of fields to list elements. The "tail" tag, which may only be used on the last exported +struct field, allows slurping up any excess list elements into a slice. + + type StructWithTail struct{ + Field uint + Tail []string `rlp:"tail"` + } + +The "optional" tag says that the field may be omitted if it is zero-valued. If this tag is +used on a struct field, all subsequent public fields must also be declared optional. + +When encoding a struct with optional fields, the output RLP list contains all values up to +the last non-zero optional field. + +When decoding into a struct, optional fields may be omitted from the end of the input +list. For the example below, this means input lists of one, two, or three elements are +accepted. + + type StructWithOptionalFields struct{ + Required uint + Optional1 uint `rlp:"optional"` + Optional2 uint `rlp:"optional"` + } + +The "nil", "nilList" and "nilString" tags apply to pointer-typed fields only, and change +the decoding rules for the field type. For regular pointer fields without the "nil" tag, +input values must always match the required input length exactly and the decoder does not +produce nil values. When the "nil" tag is set, input values of size zero decode as a nil +pointer. This is especially useful for recursive types. + + type StructWithNilField struct { + Field *[3]byte `rlp:"nil"` + } + +In the example above, Field allows two possible input sizes. For input 0xC180 (a list +containing an empty string) Field is set to nil after decoding. For input 0xC483000000 (a +list containing a 3-byte string), Field is set to a non-nil array pointer. + +RLP supports two kinds of empty values: empty lists and empty strings. When using the +"nil" tag, the kind of empty value allowed for a type is chosen automatically. A field +whose Go type is a pointer to an unsigned integer, string, boolean or byte array/slice +expects an empty RLP string. Any other pointer field type encodes/decodes as an empty RLP +list. + +The choice of null value can be made explicit with the "nilList" and "nilString" struct +tags. Using these tags encodes/decodes a Go nil pointer value as the empty RLP value kind +defined by the tag. */ package rlp diff --git a/rlp/encbuffer.go b/rlp/encbuffer.go new file mode 100644 index 0000000000..8d3a3b2293 --- /dev/null +++ b/rlp/encbuffer.go @@ -0,0 +1,423 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "encoding/binary" + "io" + "math/big" + "reflect" + "sync" + + "github.com/holiman/uint256" +) + +type encBuffer struct { + str []byte // string data, contains everything except list headers + lheads []listhead // all list headers + lhsize int // sum of sizes of all encoded list headers + sizebuf [9]byte // auxiliary buffer for uint encoding +} + +// The global encBuffer pool. +var encBufferPool = sync.Pool{ + New: func() interface{} { return new(encBuffer) }, +} + +func getEncBuffer() *encBuffer { + buf := encBufferPool.Get().(*encBuffer) + buf.reset() + return buf +} + +func (buf *encBuffer) reset() { + buf.lhsize = 0 + buf.str = buf.str[:0] + buf.lheads = buf.lheads[:0] +} + +// size returns the length of the encoded data. +func (buf *encBuffer) size() int { + return len(buf.str) + buf.lhsize +} + +// makeBytes creates the encoder output. +func (buf *encBuffer) makeBytes() []byte { + out := make([]byte, buf.size()) + buf.copyTo(out) + return out +} + +func (buf *encBuffer) copyTo(dst []byte) { + strpos := 0 + pos := 0 + for _, head := range buf.lheads { + // write string data before header + n := copy(dst[pos:], buf.str[strpos:head.offset]) + pos += n + strpos += n + // write the header + enc := head.encode(dst[pos:]) + pos += len(enc) + } + // copy string data after the last list header + copy(dst[pos:], buf.str[strpos:]) +} + +// writeTo writes the encoder output to w. +func (buf *encBuffer) writeTo(w io.Writer) (err error) { + strpos := 0 + for _, head := range buf.lheads { + // write string data before header + if head.offset-strpos > 0 { + n, err := w.Write(buf.str[strpos:head.offset]) + strpos += n + if err != nil { + return err + } + } + // write the header + enc := head.encode(buf.sizebuf[:]) + if _, err = w.Write(enc); err != nil { + return err + } + } + if strpos < len(buf.str) { + // write string data after the last list header + _, err = w.Write(buf.str[strpos:]) + } + return err +} + +// Write implements io.Writer and appends b directly to the output. +func (buf *encBuffer) Write(b []byte) (int, error) { + buf.str = append(buf.str, b...) + return len(b), nil +} + +// writeBool writes b as the integer 0 (false) or 1 (true). +func (buf *encBuffer) writeBool(b bool) { + if b { + buf.str = append(buf.str, 0x01) + } else { + buf.str = append(buf.str, 0x80) + } +} + +func (buf *encBuffer) writeUint64(i uint64) { + if i == 0 { + buf.str = append(buf.str, 0x80) + } else if i < 128 { + // fits single byte + buf.str = append(buf.str, byte(i)) + } else { + s := putint(buf.sizebuf[1:], i) + buf.sizebuf[0] = 0x80 + byte(s) + buf.str = append(buf.str, buf.sizebuf[:s+1]...) + } +} + +func (buf *encBuffer) writeBytes(b []byte) { + if len(b) == 1 && b[0] <= 0x7F { + // fits single byte, no string header + buf.str = append(buf.str, b[0]) + } else { + buf.encodeStringHeader(len(b)) + buf.str = append(buf.str, b...) + } +} + +func (buf *encBuffer) writeString(s string) { + buf.writeBytes([]byte(s)) +} + +// wordBytes is the number of bytes in a big.Word +const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8 + +// writeBigInt writes i as an integer. +func (buf *encBuffer) writeBigInt(i *big.Int) { + bitlen := i.BitLen() + if bitlen <= 64 { + buf.writeUint64(i.Uint64()) + return + } + // Integer is larger than 64 bits, encode from i.Bits(). + // The minimal byte length is bitlen rounded up to the next + // multiple of 8, divided by 8. + length := ((bitlen + 7) & -8) >> 3 + buf.encodeStringHeader(length) + buf.str = append(buf.str, make([]byte, length)...) + index := length + bytesBuf := buf.str[len(buf.str)-length:] + for _, d := range i.Bits() { + for j := 0; j < wordBytes && index > 0; j++ { + index-- + bytesBuf[index] = byte(d) + d >>= 8 + } + } +} + +// writeUint256 writes z as an integer. +func (buf *encBuffer) writeUint256(z *uint256.Int) { + bitlen := z.BitLen() + if bitlen <= 64 { + buf.writeUint64(z.Uint64()) + return + } + nBytes := byte((bitlen + 7) / 8) + var b [33]byte + binary.BigEndian.PutUint64(b[1:9], z[3]) + binary.BigEndian.PutUint64(b[9:17], z[2]) + binary.BigEndian.PutUint64(b[17:25], z[1]) + binary.BigEndian.PutUint64(b[25:33], z[0]) + b[32-nBytes] = 0x80 + nBytes + buf.str = append(buf.str, b[32-nBytes:]...) +} + +// list adds a new list header to the header stack. It returns the index of the header. +// Call listEnd with this index after encoding the content of the list. +func (buf *encBuffer) list() int { + buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize}) + return len(buf.lheads) - 1 +} + +func (buf *encBuffer) listEnd(index int) { + lh := &buf.lheads[index] + lh.size = buf.size() - lh.offset - lh.size + if lh.size < 56 { + buf.lhsize++ // length encoded into kind tag + } else { + buf.lhsize += 1 + intsize(uint64(lh.size)) + } +} + +func (buf *encBuffer) encode(val interface{}) error { + rval := reflect.ValueOf(val) + writer, err := cachedWriter(rval.Type()) + if err != nil { + return err + } + return writer(rval, buf) +} + +func (buf *encBuffer) encodeStringHeader(size int) { + if size < 56 { + buf.str = append(buf.str, 0x80+byte(size)) + } else { + sizesize := putint(buf.sizebuf[1:], uint64(size)) + buf.sizebuf[0] = 0xB7 + byte(sizesize) + buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...) + } +} + +// encReader is the io.Reader returned by EncodeToReader. +// It releases its encbuf at EOF. +type encReader struct { + buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF. + lhpos int // index of list header that we're reading + strpos int // current position in string buffer + piece []byte // next piece to be read +} + +func (r *encReader) Read(b []byte) (n int, err error) { + for { + if r.piece = r.next(); r.piece == nil { + // Put the encode buffer back into the pool at EOF when it + // is first encountered. Subsequent calls still return EOF + // as the error but the buffer is no longer valid. + if r.buf != nil { + encBufferPool.Put(r.buf) + r.buf = nil + } + return n, io.EOF + } + nn := copy(b[n:], r.piece) + n += nn + if nn < len(r.piece) { + // piece didn't fit, see you next time. + r.piece = r.piece[nn:] + return n, nil + } + r.piece = nil + } +} + +// next returns the next piece of data to be read. +// it returns nil at EOF. +func (r *encReader) next() []byte { + switch { + case r.buf == nil: + return nil + + case r.piece != nil: + // There is still data available for reading. + return r.piece + + case r.lhpos < len(r.buf.lheads): + // We're before the last list header. + head := r.buf.lheads[r.lhpos] + sizebefore := head.offset - r.strpos + if sizebefore > 0 { + // String data before header. + p := r.buf.str[r.strpos:head.offset] + r.strpos += sizebefore + return p + } + r.lhpos++ + return head.encode(r.buf.sizebuf[:]) + + case r.strpos < len(r.buf.str): + // String data at the end, after all list headers. + p := r.buf.str[r.strpos:] + r.strpos = len(r.buf.str) + return p + + default: + return nil + } +} + +func encBufferFromWriter(w io.Writer) *encBuffer { + switch w := w.(type) { + case EncoderBuffer: + return w.buf + case *EncoderBuffer: + return w.buf + case *encBuffer: + return w + default: + return nil + } +} + +// EncoderBuffer is a buffer for incremental encoding. +// +// The zero value is NOT ready for use. To get a usable buffer, +// create it using NewEncoderBuffer or call Reset. +type EncoderBuffer struct { + buf *encBuffer + dst io.Writer + + ownBuffer bool +} + +// NewEncoderBuffer creates an encoder buffer. +func NewEncoderBuffer(dst io.Writer) EncoderBuffer { + var w EncoderBuffer + w.Reset(dst) + return w +} + +// Reset truncates the buffer and sets the output destination. +func (w *EncoderBuffer) Reset(dst io.Writer) { + if w.buf != nil && !w.ownBuffer { + panic("can't Reset derived EncoderBuffer") + } + + // If the destination writer has an *encBuffer, use it. + // Note that w.ownBuffer is left false here. + if dst != nil { + if outer := encBufferFromWriter(dst); outer != nil { + *w = EncoderBuffer{outer, nil, false} + return + } + } + + // Get a fresh buffer. + if w.buf == nil { + w.buf = encBufferPool.Get().(*encBuffer) + w.ownBuffer = true + } + w.buf.reset() + w.dst = dst +} + +// Flush writes encoded RLP data to the output writer. This can only be called once. +// If you want to re-use the buffer after Flush, you must call Reset. +func (w *EncoderBuffer) Flush() error { + var err error + if w.dst != nil { + err = w.buf.writeTo(w.dst) + } + // Release the internal buffer. + if w.ownBuffer { + encBufferPool.Put(w.buf) + } + *w = EncoderBuffer{} + return err +} + +// ToBytes returns the encoded bytes. +func (w *EncoderBuffer) ToBytes() []byte { + return w.buf.makeBytes() +} + +// AppendToBytes appends the encoded bytes to dst. +func (w *EncoderBuffer) AppendToBytes(dst []byte) []byte { + size := w.buf.size() + out := append(dst, make([]byte, size)...) + w.buf.copyTo(out[len(dst):]) + return out +} + +// Write appends b directly to the encoder output. +func (w EncoderBuffer) Write(b []byte) (int, error) { + return w.buf.Write(b) +} + +// WriteBool writes b as the integer 0 (false) or 1 (true). +func (w EncoderBuffer) WriteBool(b bool) { + w.buf.writeBool(b) +} + +// WriteUint64 encodes an unsigned integer. +func (w EncoderBuffer) WriteUint64(i uint64) { + w.buf.writeUint64(i) +} + +// WriteBigInt encodes a big.Int as an RLP string. +// Note: Unlike with Encode, the sign of i is ignored. +func (w EncoderBuffer) WriteBigInt(i *big.Int) { + w.buf.writeBigInt(i) +} + +// WriteUint256 encodes uint256.Int as an RLP string. +func (w EncoderBuffer) WriteUint256(i *uint256.Int) { + w.buf.writeUint256(i) +} + +// WriteBytes encodes b as an RLP string. +func (w EncoderBuffer) WriteBytes(b []byte) { + w.buf.writeBytes(b) +} + +// WriteString encodes s as an RLP string. +func (w EncoderBuffer) WriteString(s string) { + w.buf.writeString(s) +} + +// List starts a list. It returns an internal index. Call EndList with +// this index after encoding the content to finish the list. +func (w EncoderBuffer) List() int { + return w.buf.list() +} + +// ListEnd finishes the given list. +func (w EncoderBuffer) ListEnd(index int) { + w.buf.listEnd(index) +} diff --git a/rlp/encbuffer_example_test.go b/rlp/encbuffer_example_test.go new file mode 100644 index 0000000000..c41de60f02 --- /dev/null +++ b/rlp/encbuffer_example_test.go @@ -0,0 +1,45 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp_test + +import ( + "bytes" + "fmt" + + "github.com/tomochain/tomochain/rlp" +) + +func ExampleEncoderBuffer() { + var w bytes.Buffer + + // Encode [4, [5, 6]] to w. + buf := rlp.NewEncoderBuffer(&w) + l1 := buf.List() + buf.WriteUint64(4) + l2 := buf.List() + buf.WriteUint64(5) + buf.WriteUint64(6) + buf.ListEnd(l2) + buf.ListEnd(l1) + + if err := buf.Flush(); err != nil { + panic(err) + } + fmt.Printf("%X\n", w.Bytes()) + // Output: + // C404C20506 +} diff --git a/rlp/encode.go b/rlp/encode.go index 44592c2f53..2ca283c0a3 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -17,20 +17,28 @@ package rlp import ( + "errors" "fmt" "io" "math/big" "reflect" - "sync" + + "github.com/holiman/uint256" + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) var ( // Common encoded values. // These are useful when implementing EncodeRLP. + + // EmptyString is the encoding of an empty string. EmptyString = []byte{0x80} - EmptyList = []byte{0xC0} + // EmptyList is the encoding of an empty list. + EmptyList = []byte{0xC0} ) +var ErrNegativeBigInt = errors.New("rlp: cannot encode negative big.Int") + // Encoder is implemented by types that require custom // encoding rules or want to encode private fields. type Encoder interface { @@ -49,80 +57,48 @@ type Encoder interface { // perform many small writes in some cases. Consider making w // buffered. // -// Encode uses the following type-dependent encoding rules: -// -// If the type implements the Encoder interface, Encode calls -// EncodeRLP. This is true even for nil pointers, please see the -// documentation for Encoder. -// -// To encode a pointer, the value being pointed to is encoded. For nil -// pointers, Encode will encode the zero value of the type. A nil -// pointer to a struct type always encodes as an empty RLP list. -// A nil pointer to an array encodes as an empty list (or empty string -// if the array has element type byte). -// -// Struct values are encoded as an RLP list of all their encoded -// public fields. Recursive struct types are supported. -// -// To encode slices and arrays, the elements are encoded as an RLP -// list of the value's elements. Note that arrays and slices with -// element type uint8 or byte are always encoded as an RLP string. -// -// A Go string is encoded as an RLP string. -// -// An unsigned integer value is encoded as an RLP string. Zero always -// encodes as an empty RLP string. Encode also supports *big.Int. -// -// An interface value encodes as the value contained in the interface. -// -// Boolean values are not supported, nor are signed integers, floating -// point numbers, maps, channels and functions. +// Please see package-level documentation of encoding rules. func Encode(w io.Writer, val interface{}) error { - if outer, ok := w.(*encbuf); ok { - // Encode was called by some type's EncodeRLP. - // Avoid copying by writing to the outer encbuf directly. - return outer.encode(val) + // Optimization: reuse *encBuffer when called by EncodeRLP. + if buf := encBufferFromWriter(w); buf != nil { + return buf.encode(val) } - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + + buf := getEncBuffer() + defer encBufferPool.Put(buf) + if err := buf.encode(val); err != nil { return err } - return eb.toWriter(w) + return buf.writeTo(w) } -// EncodeBytes returns the RLP encoding of val. -// Please see the documentation of Encode for the encoding rules. +// EncodeToBytes returns the RLP encoding of val. +// Please see package-level documentation for the encoding rules. func EncodeToBytes(val interface{}) ([]byte, error) { - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + defer encBufferPool.Put(buf) + + if err := buf.encode(val); err != nil { return nil, err } - return eb.toBytes(), nil + return buf.makeBytes(), nil } -// EncodeReader returns a reader from which the RLP encoding of val +// EncodeToReader returns a reader from which the RLP encoding of val // can be read. The returned size is the total size of the encoded // data. // // Please see the documentation of Encode for the encoding rules. func EncodeToReader(val interface{}) (size int, r io.Reader, err error) { - eb := encbufPool.Get().(*encbuf) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + if err := buf.encode(val); err != nil { + encBufferPool.Put(buf) return 0, nil, err } - return eb.size(), &encReader{buf: eb}, nil -} - -type encbuf struct { - str []byte // string data, contains everything except list headers - lheads []*listhead // all list headers - lhsize int // sum of sizes of all encoded list headers - sizebuf []byte // 9-byte auxiliary buffer for uint encoding + // Note: can't put the reader back into the pool here + // because it is held by encReader. The reader puts it + // back when it has been fully consumed. + return buf.size(), &encReader{buf: buf}, nil } type listhead struct { @@ -151,214 +127,32 @@ func puthead(buf []byte, smalltag, largetag byte, size uint64) int { if size < 56 { buf[0] = smalltag + byte(size) return 1 - } else { - sizesize := putint(buf[1:], size) - buf[0] = largetag + byte(sizesize) - return sizesize + 1 - } -} - -// encbufs are pooled. -var encbufPool = sync.Pool{ - New: func() interface{} { return &encbuf{sizebuf: make([]byte, 9)} }, -} - -func (w *encbuf) reset() { - w.lhsize = 0 - if w.str != nil { - w.str = w.str[:0] - } - if w.lheads != nil { - w.lheads = w.lheads[:0] - } -} - -// encbuf implements io.Writer so it can be passed it into EncodeRLP. -func (w *encbuf) Write(b []byte) (int, error) { - w.str = append(w.str, b...) - return len(b), nil -} - -func (w *encbuf) encode(val interface{}) error { - rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type(), tags{}) - if err != nil { - return err - } - return ti.writer(rval, w) -} - -func (w *encbuf) encodeStringHeader(size int) { - if size < 56 { - w.str = append(w.str, 0x80+byte(size)) - } else { - // TODO: encode to w.str directly - sizesize := putint(w.sizebuf[1:], uint64(size)) - w.sizebuf[0] = 0xB7 + byte(sizesize) - w.str = append(w.str, w.sizebuf[:sizesize+1]...) - } -} - -func (w *encbuf) encodeString(b []byte) { - if len(b) == 1 && b[0] <= 0x7F { - // fits single byte, no string header - w.str = append(w.str, b[0]) - } else { - w.encodeStringHeader(len(b)) - w.str = append(w.str, b...) - } -} - -func (w *encbuf) list() *listhead { - lh := &listhead{offset: len(w.str), size: w.lhsize} - w.lheads = append(w.lheads, lh) - return lh -} - -func (w *encbuf) listEnd(lh *listhead) { - lh.size = w.size() - lh.offset - lh.size - if lh.size < 56 { - w.lhsize += 1 // length encoded into kind tag - } else { - w.lhsize += 1 + intsize(uint64(lh.size)) - } -} - -func (w *encbuf) size() int { - return len(w.str) + w.lhsize -} - -func (w *encbuf) toBytes() []byte { - out := make([]byte, w.size()) - strpos := 0 - pos := 0 - for _, head := range w.lheads { - // write string data before header - n := copy(out[pos:], w.str[strpos:head.offset]) - pos += n - strpos += n - // write the header - enc := head.encode(out[pos:]) - pos += len(enc) - } - // copy string data after the last list header - copy(out[pos:], w.str[strpos:]) - return out -} - -func (w *encbuf) toWriter(out io.Writer) (err error) { - strpos := 0 - for _, head := range w.lheads { - // write string data before header - if head.offset-strpos > 0 { - n, err := out.Write(w.str[strpos:head.offset]) - strpos += n - if err != nil { - return err - } - } - // write the header - enc := head.encode(w.sizebuf) - if _, err = out.Write(enc); err != nil { - return err - } - } - if strpos < len(w.str) { - // write string data after the last list header - _, err = out.Write(w.str[strpos:]) - } - return err -} - -// encReader is the io.Reader returned by EncodeToReader. -// It releases its encbuf at EOF. -type encReader struct { - buf *encbuf // the buffer we're reading from. this is nil when we're at EOF. - lhpos int // index of list header that we're reading - strpos int // current position in string buffer - piece []byte // next piece to be read -} - -func (r *encReader) Read(b []byte) (n int, err error) { - for { - if r.piece = r.next(); r.piece == nil { - // Put the encode buffer back into the pool at EOF when it - // is first encountered. Subsequent calls still return EOF - // as the error but the buffer is no longer valid. - if r.buf != nil { - encbufPool.Put(r.buf) - r.buf = nil - } - return n, io.EOF - } - nn := copy(b[n:], r.piece) - n += nn - if nn < len(r.piece) { - // piece didn't fit, see you next time. - r.piece = r.piece[nn:] - return n, nil - } - r.piece = nil - } -} - -// next returns the next piece of data to be read. -// it returns nil at EOF. -func (r *encReader) next() []byte { - switch { - case r.buf == nil: - return nil - - case r.piece != nil: - // There is still data available for reading. - return r.piece - - case r.lhpos < len(r.buf.lheads): - // We're before the last list header. - head := r.buf.lheads[r.lhpos] - sizebefore := head.offset - r.strpos - if sizebefore > 0 { - // String data before header. - p := r.buf.str[r.strpos:head.offset] - r.strpos += sizebefore - return p - } else { - r.lhpos++ - return head.encode(r.buf.sizebuf) - } - - case r.strpos < len(r.buf.str): - // String data at the end, after all list headers. - p := r.buf.str[r.strpos:] - r.strpos = len(r.buf.str) - return p - - default: - return nil } + sizesize := putint(buf[1:], size) + buf[0] = largetag + byte(sizesize) + return sizesize + 1 } -var ( - encoderInterface = reflect.TypeOf(new(Encoder)).Elem() - big0 = big.NewInt(0) -) +var encoderInterface = reflect.TypeOf(new(Encoder)).Elem() // makeWriter creates a writer function for the given type. -func makeWriter(typ reflect.Type, ts tags) (writer, error) { +func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { kind := typ.Kind() switch { case typ == rawValueType: return writeRawValue, nil - case typ.Implements(encoderInterface): - return writeEncoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(encoderInterface): - return writeEncoderNoPtr, nil - case kind == reflect.Interface: - return writeInterface, nil case typ.AssignableTo(reflect.PtrTo(bigInt)): return writeBigIntPtr, nil case typ.AssignableTo(bigInt): return writeBigIntNoPtr, nil + case typ == reflect.PtrTo(u256Int): + return writeU256IntPtr, nil + case typ == u256Int: + return writeU256IntNoPtr, nil + case kind == reflect.Ptr: + return makePtrWriter(typ, ts) + case reflect.PtrTo(typ).Implements(encoderInterface): + return makeEncoderWriter(typ), nil case isUint(kind): return writeUint, nil case kind == reflect.Bool: @@ -368,97 +162,116 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) { case kind == reflect.Slice && isByte(typ.Elem()): return writeBytes, nil case kind == reflect.Array && isByte(typ.Elem()): - return writeByteArray, nil + return makeByteArrayWriter(typ), nil case kind == reflect.Slice || kind == reflect.Array: return makeSliceWriter(typ, ts) case kind == reflect.Struct: return makeStructWriter(typ) - case kind == reflect.Ptr: - return makePtrWriter(typ) + case kind == reflect.Interface: + return writeInterface, nil default: return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ) } } -func isByte(typ reflect.Type) bool { - return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) -} - -func writeRawValue(val reflect.Value, w *encbuf) error { +func writeRawValue(val reflect.Value, w *encBuffer) error { w.str = append(w.str, val.Bytes()...) return nil } -func writeUint(val reflect.Value, w *encbuf) error { - i := val.Uint() - if i == 0 { - w.str = append(w.str, 0x80) - } else if i < 128 { - // fits single byte - w.str = append(w.str, byte(i)) - } else { - // TODO: encode int to w.str directly - s := putint(w.sizebuf[1:], i) - w.sizebuf[0] = 0x80 + byte(s) - w.str = append(w.str, w.sizebuf[:s+1]...) - } +func writeUint(val reflect.Value, w *encBuffer) error { + w.writeUint64(val.Uint()) return nil } -func writeBool(val reflect.Value, w *encbuf) error { - if val.Bool() { - w.str = append(w.str, 0x01) - } else { - w.str = append(w.str, 0x80) - } +func writeBool(val reflect.Value, w *encBuffer) error { + w.writeBool(val.Bool()) return nil } -func writeBigIntPtr(val reflect.Value, w *encbuf) error { +func writeBigIntPtr(val reflect.Value, w *encBuffer) error { ptr := val.Interface().(*big.Int) if ptr == nil { w.str = append(w.str, 0x80) return nil } - return writeBigInt(ptr, w) + if ptr.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(ptr) + return nil } -func writeBigIntNoPtr(val reflect.Value, w *encbuf) error { +func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error { i := val.Interface().(big.Int) - return writeBigInt(&i, w) + if i.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(&i) + return nil } -func writeBigInt(i *big.Int, w *encbuf) error { - if cmp := i.Cmp(big0); cmp == -1 { - return fmt.Errorf("rlp: cannot encode negative *big.Int") - } else if cmp == 0 { +func writeU256IntPtr(val reflect.Value, w *encBuffer) error { + ptr := val.Interface().(*uint256.Int) + if ptr == nil { w.str = append(w.str, 0x80) - } else { - w.encodeString(i.Bytes()) + return nil } + w.writeUint256(ptr) + return nil +} + +func writeU256IntNoPtr(val reflect.Value, w *encBuffer) error { + i := val.Interface().(uint256.Int) + w.writeUint256(&i) return nil } -func writeBytes(val reflect.Value, w *encbuf) error { - w.encodeString(val.Bytes()) +func writeBytes(val reflect.Value, w *encBuffer) error { + w.writeBytes(val.Bytes()) return nil } -func writeByteArray(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // Slice requires the value to be addressable. - // Make it addressable by copying. - copy := reflect.New(val.Type()).Elem() - copy.Set(val) - val = copy +func makeByteArrayWriter(typ reflect.Type) writer { + switch typ.Len() { + case 0: + return writeLengthZeroByteArray + case 1: + return writeLengthOneByteArray + default: + length := typ.Len() + return func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // Getting the byte slice of val requires it to be addressable. Make it + // addressable by copying. + copy := reflect.New(val.Type()).Elem() + copy.Set(val) + val = copy + } + slice := byteArrayBytes(val, length) + w.encodeStringHeader(len(slice)) + w.str = append(w.str, slice...) + return nil + } } - size := val.Len() - slice := val.Slice(0, size).Bytes() - w.encodeString(slice) +} + +func writeLengthZeroByteArray(val reflect.Value, w *encBuffer) error { + w.str = append(w.str, 0x80) return nil } -func writeString(val reflect.Value, w *encbuf) error { +func writeLengthOneByteArray(val reflect.Value, w *encBuffer) error { + b := byte(val.Index(0).Uint()) + if b <= 0x7f { + w.str = append(w.str, b) + } else { + w.str = append(w.str, 0x81, b) + } + return nil +} + +func writeString(val reflect.Value, w *encBuffer) error { s := val.String() if len(s) == 1 && s[0] <= 0x7f { // fits single byte, no string header @@ -470,27 +283,7 @@ func writeString(val reflect.Value, w *encbuf) error { return nil } -func writeEncoder(val reflect.Value, w *encbuf) error { - return val.Interface().(Encoder).EncodeRLP(w) -} - -// writeEncoderNoPtr handles non-pointer values that implement Encoder -// with a pointer receiver. -func writeEncoderNoPtr(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // We can't get the address. It would be possible to make the - // value addressable by creating a shallow copy, but this - // creates other problems so we're not doing it (yet). - // - // package json simply doesn't call MarshalJSON for cases like - // this, but encodes the value as if it didn't implement the - // interface. We don't want to handle it that way. - return fmt.Errorf("rlp: game over: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) - } - return val.Addr().Interface().(Encoder).EncodeRLP(w) -} - -func writeInterface(val reflect.Value, w *encbuf) error { +func writeInterface(val reflect.Value, w *encBuffer) error { if val.IsNil() { // Write empty list. This is consistent with the previous RLP // encoder that we had and should therefore avoid any @@ -499,31 +292,51 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type(), tags{}) + writer, err := cachedWriter(eval.Type()) if err != nil { return err } - return ti.writer(eval, w) + return writer(eval, w) } -func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makeSliceWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } - writer := func(val reflect.Value, w *encbuf) error { - if !ts.tail { - defer w.listEnd(w.list()) + + var wfn writer + if ts.Tail { + // This is for struct tail slices. + // w.list is not called for them. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } + } + return nil } - vlen := val.Len() - for i := 0; i < vlen; i++ { - if err := etypeinfo.writer(val.Index(i), w); err != nil { - return err + } else { + // This is for regular slices and arrays. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + if vlen == 0 { + w.str = append(w.str, 0xC0) + return nil + } + listOffset := w.list() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } } + w.listEnd(listOffset) + return nil } - return nil } - return writer, nil + return wfn, nil } func makeStructWriter(typ reflect.Type) (writer, error) { @@ -531,56 +344,86 @@ func makeStructWriter(typ reflect.Type) (writer, error) { if err != nil { return nil, err } - writer := func(val reflect.Value, w *encbuf) error { - lh := w.list() - for _, f := range fields { - if err := f.info.writer(val.Field(f.index), w); err != nil { - return err + for _, f := range fields { + if f.info.writerErr != nil { + return nil, structFieldError{typ, f.index, f.info.writerErr} + } + } + + var writer writer + firstOptionalField := firstOptionalField(fields) + if firstOptionalField == len(fields) { + // This is the writer function for structs without any optional fields. + writer = func(val reflect.Value, w *encBuffer) error { + lh := w.list() + for _, f := range fields { + if err := f.info.writer(val.Field(f.index), w); err != nil { + return err + } } + w.listEnd(lh) + return nil + } + } else { + // If there are any "optional" fields, the writer needs to perform additional + // checks to determine the output list length. + writer = func(val reflect.Value, w *encBuffer) error { + lastField := len(fields) - 1 + for ; lastField >= firstOptionalField; lastField-- { + if !val.Field(fields[lastField].index).IsZero() { + break + } + } + lh := w.list() + for i := 0; i <= lastField; i++ { + if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil { + return err + } + } + w.listEnd(lh) + return nil } - w.listEnd(lh) - return nil } return writer, nil } -func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makePtrWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + nilEncoding := byte(0xC0) + if typeNilKind(typ.Elem(), ts) == String { + nilEncoding = 0x80 } - // determine nil pointer handler - var nilfunc func(*encbuf) error - kind := typ.Elem().Kind() - switch { - case kind == reflect.Array && isByte(typ.Elem().Elem()): - nilfunc = func(w *encbuf) error { - w.str = append(w.str, 0x80) - return nil - } - case kind == reflect.Struct || kind == reflect.Array: - nilfunc = func(w *encbuf) error { - // encoding the zero value of a struct/array could trigger - // infinite recursion, avoid that. - w.listEnd(w.list()) - return nil - } - default: - zero := reflect.Zero(typ.Elem()) - nilfunc = func(w *encbuf) error { - return etypeinfo.writer(zero, w) + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr + } + + writer := func(val reflect.Value, w *encBuffer) error { + if ev := val.Elem(); ev.IsValid() { + return etypeinfo.writer(ev, w) } + w.str = append(w.str, nilEncoding) + return nil } + return writer, nil +} - writer := func(val reflect.Value, w *encbuf) error { - if val.IsNil() { - return nilfunc(w) - } else { - return etypeinfo.writer(val.Elem(), w) +func makeEncoderWriter(typ reflect.Type) writer { + if typ.Implements(encoderInterface) { + return func(val reflect.Value, w *encBuffer) error { + return val.Interface().(Encoder).EncodeRLP(w) + } + } + w := func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // package json simply doesn't call MarshalJSON for this case, but encodes the + // value as if it didn't implement the interface. We don't want to handle it that + // way. + return fmt.Errorf("rlp: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) } + return val.Addr().Interface().(Encoder).EncodeRLP(w) } - return writer, err + return w } // putint writes i to the beginning of b in big endian byte diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 827960f7c1..9f2e6c38f9 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -21,10 +21,13 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math/big" + "runtime" "sync" "testing" + + "github.com/holiman/uint256" + "github.com/tomochain/tomochain/common/math" ) type testEncoder struct { @@ -33,12 +36,19 @@ type testEncoder struct { func (e *testEncoder) EncodeRLP(w io.Writer) error { if e == nil { - w.Write([]byte{0, 0, 0, 0}) - } else if e.err != nil { + panic("EncodeRLP called on nil value") + } + if e.err != nil { return e.err - } else { - w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) } + w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) + return nil +} + +type testEncoderValueMethod struct{} + +func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xFA, 0xFE, 0xF0}) return nil } @@ -49,6 +59,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { return nil } +type undecodableEncoder func() + +func (f undecodableEncoder) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xF5, 0xF5, 0xF5}) + return nil +} + type encodableReader struct { A, B uint } @@ -103,35 +120,95 @@ var encTests = []encTest{ {val: big.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, {val: big.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, { - val: big.NewInt(0).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + val: new(big.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), output: "8F102030405060708090A0B0C0D0E0F2", }, { - val: big.NewInt(0).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + val: new(big.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), output: "9C0100020003000400050006000700080009000A000B000C000D000E01", }, { - val: big.NewInt(0).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), + val: new(big.Int).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), output: "A1010000000000000000000000000000000000000000000000000000000000000000", }, + { + val: veryBigInt, + output: "89FFFFFFFFFFFFFFFFFF", + }, + { + val: veryVeryBigInt, + output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", + }, // non-pointer big.Int {val: *big.NewInt(0), output: "80"}, {val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"}, // negative ints are not supported - {val: big.NewInt(-1), error: "rlp: cannot encode negative *big.Int"}, - - // byte slices, strings + {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + + // uint256 + {val: uint256.NewInt(0), output: "80"}, + {val: uint256.NewInt(1), output: "01"}, + {val: uint256.NewInt(127), output: "7F"}, + {val: uint256.NewInt(128), output: "8180"}, + {val: uint256.NewInt(256), output: "820100"}, + {val: uint256.NewInt(1024), output: "820400"}, + {val: uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, + { + val: new(uint256.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + output: "8F102030405060708090A0B0C0D0E0F2", + }, + { + val: new(uint256.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + output: "9C0100020003000400050006000700080009000A000B000C000D000E01", + }, + // non-pointer uint256.Int + {val: *uint256.NewInt(0), output: "80"}, + {val: *uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + + // byte arrays + {val: [0]byte{}, output: "80"}, + {val: [1]byte{0}, output: "00"}, + {val: [1]byte{1}, output: "01"}, + {val: [1]byte{0x7F}, output: "7F"}, + {val: [1]byte{0x80}, output: "8180"}, + {val: [1]byte{0xFF}, output: "81FF"}, + {val: [3]byte{1, 2, 3}, output: "83010203"}, + {val: [57]byte{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // named byte type arrays + {val: [0]namedByteType{}, output: "80"}, + {val: [1]namedByteType{0}, output: "00"}, + {val: [1]namedByteType{1}, output: "01"}, + {val: [1]namedByteType{0x7F}, output: "7F"}, + {val: [1]namedByteType{0x80}, output: "8180"}, + {val: [1]namedByteType{0xFF}, output: "81FF"}, + {val: [3]namedByteType{1, 2, 3}, output: "83010203"}, + {val: [57]namedByteType{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // byte slices {val: []byte{}, output: "80"}, + {val: []byte{0}, output: "00"}, {val: []byte{0x7E}, output: "7E"}, {val: []byte{0x7F}, output: "7F"}, {val: []byte{0x80}, output: "8180"}, {val: []byte{1, 2, 3}, output: "83010203"}, + // named byte type slices + {val: []namedByteType{}, output: "80"}, + {val: []namedByteType{0}, output: "00"}, + {val: []namedByteType{0x7E}, output: "7E"}, + {val: []namedByteType{0x7F}, output: "7F"}, + {val: []namedByteType{0x80}, output: "8180"}, {val: []namedByteType{1, 2, 3}, output: "83010203"}, - {val: [...]namedByteType{1, 2, 3}, output: "83010203"}, + // strings {val: "", output: "80"}, {val: "\x7E", output: "7E"}, {val: "\x7F", output: "7F"}, @@ -204,6 +281,12 @@ var encTests = []encTest{ output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376", }, + // Non-byte arrays are encoded as lists. + // Note that it is important to test [4]uint64 specifically, + // because that's the underlying type of uint256.Int. + {val: [4]uint32{1, 2, 3, 4}, output: "C401020304"}, + {val: [4]uint64{1, 2, 3, 4}, output: "C401020304"}, + // RawValue {val: RawValue(unhex("01")), output: "01"}, {val: RawValue(unhex("82FFFF")), output: "82FFFF"}, @@ -214,11 +297,34 @@ var encTests = []encTest{ {val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"}, {val: &recstruct{5, nil}, output: "C205C0"}, {val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"}, + {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"}, + + // struct tag "-" + {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "tail" {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"}, {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"}, {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"}, {val: &tailRaw{A: 1, Tail: nil}, output: "C101"}, - {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "optional" + {val: &optionalFields{}, output: "C180"}, + {val: &optionalFields{A: 1}, output: "C101"}, + {val: &optionalFields{A: 1, B: 2}, output: "C20102"}, + {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"}, + {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"}, + {val: &optionalAndTailField{A: 1}, output: "C101"}, + {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalBigIntField{A: 1}, output: "C101"}, + {val: &optionalPtrField{A: 1}, output: "C101"}, + {val: &optionalPtrFieldNil{A: 1}, output: "C101"}, + {val: &multipleOptionalFields{A: nil, B: nil}, output: "C0"}, + {val: &multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, output: "C88301020383010203"}, + {val: &multipleOptionalFields{A: nil, B: &[3]byte{1, 2, 3}}, output: "C58083010203"}, // encodes without error but decode will fail + {val: &nonOptionalPtrField{A: 1}, output: "C20180"}, // encodes without error but decode will fail // nil {val: (*uint)(nil), output: "80"}, @@ -226,26 +332,73 @@ var encTests = []encTest{ {val: (*[]byte)(nil), output: "80"}, {val: (*[10]byte)(nil), output: "80"}, {val: (*big.Int)(nil), output: "80"}, + {val: (*uint256.Int)(nil), output: "80"}, {val: (*[]string)(nil), output: "C0"}, {val: (*[10]string)(nil), output: "C0"}, {val: (*[]interface{})(nil), output: "C0"}, {val: (*[]struct{ uint })(nil), output: "C0"}, {val: (*interface{})(nil), output: "C0"}, + // nil struct fields + { + val: struct { + X *[]byte + }{}, + output: "C180", + }, + { + val: struct { + X *[2]byte + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 `rlp:"nilList"` + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 `rlp:"nilString"` + }{}, + output: "C180", + }, + // interfaces {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct // Encoder - {val: (*testEncoder)(nil), output: "00000000"}, + {val: (*testEncoder)(nil), output: "C0"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, - // verify that pointer method testEncoder.EncodeRLP is called for + {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"}, + {val: struct{ E *testEncoderValueMethod }{}, output: "C1C0"}, + + // Verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "F5F5F5"}, + + // Verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, {val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"}, - // verify the error for non-addressable non-pointer Encoder - {val: testEncoder{}, error: "rlp: game over: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, - // verify the special case for []byte + + // Verify the error for non-addressable non-pointer Encoder. + {val: testEncoder{}, error: "rlp: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, + + // Verify Encoder takes precedence over []byte. {val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"}, } @@ -281,13 +434,28 @@ func TestEncodeToBytes(t *testing.T) { runEncTests(t, EncodeToBytes) } +func TestEncodeAppendToBytes(t *testing.T) { + buffer := make([]byte, 20) + runEncTests(t, func(val interface{}) ([]byte, error) { + w := NewEncoderBuffer(nil) + defer w.Flush() + + err := Encode(w, val) + if err != nil { + return nil, err + } + output := w.AppendToBytes(buffer[:0]) + return output, nil + }) +} + func TestEncodeToReader(t *testing.T) { runEncTests(t, func(val interface{}) ([]byte, error) { _, r, err := EncodeToReader(val) if err != nil { return nil, err } - return ioutil.ReadAll(r) + return io.ReadAll(r) }) } @@ -328,7 +496,7 @@ func TestEncodeToReaderReturnToPool(t *testing.T) { go func() { for i := 0; i < 1000; i++ { _, r, _ := EncodeToReader("foo") - ioutil.ReadAll(r) + io.ReadAll(r) r.Read(buf) r.Read(buf) r.Read(buf) @@ -339,3 +507,132 @@ func TestEncodeToReaderReturnToPool(t *testing.T) { } wg.Wait() } + +var sink interface{} + +func BenchmarkIntsize(b *testing.B) { + for i := 0; i < b.N; i++ { + sink = intsize(0x12345678) + } +} + +func BenchmarkPutint(b *testing.B) { + buf := make([]byte, 8) + for i := 0; i < b.N; i++ { + putint(buf, 0x12345678) + sink = buf + } +} + +func BenchmarkEncodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeConcurrentInterface(b *testing.B) { + type struct1 struct { + A string + B *big.Int + C [20]byte + } + value := []interface{}{ + uint(999), + &struct1{A: "hello", B: big.NewInt(0xFFFFFFFF)}, + [10]byte{1, 2, 3, 4, 5, 6}, + []string{"yeah", "yeah", "yeah"}, + } + + var wg sync.WaitGroup + for cpu := 0; cpu < runtime.NumCPU(); cpu++ { + wg.Add(1) + go func() { + defer wg.Done() + + var buffer bytes.Buffer + for i := 0; i < b.N; i++ { + buffer.Reset() + err := Encode(&buffer, value) + if err != nil { + panic(err) + } + } + }() + } + wg.Wait() +} + +type byteArrayStruct struct { + A [20]byte + B [32]byte + C [32]byte +} + +func BenchmarkEncodeByteArrayStruct(b *testing.B) { + var out bytes.Buffer + var value byteArrayStruct + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} + +type structSliceElem struct { + X uint64 + Y uint64 + Z uint64 +} + +type structPtrSlice []*structSliceElem + +func BenchmarkEncodeStructPtrSlice(b *testing.B) { + var out bytes.Buffer + var value = structPtrSlice{ + &structSliceElem{1, 1, 1}, + &structSliceElem{2, 2, 2}, + &structSliceElem{3, 3, 3}, + &structSliceElem{5, 5, 5}, + &structSliceElem{6, 6, 6}, + &structSliceElem{7, 7, 7}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} diff --git a/rlp/encoder_example_test.go b/rlp/encoder_example_test.go index 1cffa241c2..6291bfafe5 100644 --- a/rlp/encoder_example_test.go +++ b/rlp/encoder_example_test.go @@ -14,11 +14,13 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package rlp +package rlp_test import ( "fmt" "io" + + "github.com/tomochain/tomochain/rlp" ) type MyCoolType struct { @@ -28,27 +30,19 @@ type MyCoolType struct { // EncodeRLP writes x as RLP list [a, b] that omits the Name field. func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) { - // Note: the receiver can be a nil pointer. This allows you to - // control the encoding of nil, but it also means that you have to - // check for a nil receiver. - if x == nil { - err = Encode(w, []uint{0, 0}) - } else { - err = Encode(w, []uint{x.a, x.b}) - } - return err + return rlp.Encode(w, []uint{x.a, x.b}) } func ExampleEncoder() { var t *MyCoolType // t is nil pointer to MyCoolType - bytes, _ := EncodeToBytes(t) + bytes, _ := rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) t = &MyCoolType{Name: "foobar", a: 5, b: 6} - bytes, _ = EncodeToBytes(t) + bytes, _ = rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) // Output: - // → C28080 + // → C0 // &{foobar 5 6} → C20506 } diff --git a/rlp/internal/rlpstruct/rlpstruct.go b/rlp/internal/rlpstruct/rlpstruct.go new file mode 100644 index 0000000000..2e3eeb6881 --- /dev/null +++ b/rlp/internal/rlpstruct/rlpstruct.go @@ -0,0 +1,213 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package rlpstruct implements struct processing for RLP encoding/decoding. +// +// In particular, this package handles all rules around field filtering, +// struct tags and nil value determination. +package rlpstruct + +import ( + "fmt" + "reflect" + "strings" +) + +// Field represents a struct field. +type Field struct { + Name string + Index int + Exported bool + Type Type + Tag string +} + +// Type represents the attributes of a Go type. +type Type struct { + Name string + Kind reflect.Kind + IsEncoder bool // whether type implements rlp.Encoder + IsDecoder bool // whether type implements rlp.Decoder + Elem *Type // non-nil for Kind values of Ptr, Slice, Array +} + +// DefaultNilValue determines whether a nil pointer to t encodes/decodes +// as an empty string or empty list. +func (t Type) DefaultNilValue() NilKind { + k := t.Kind + if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) { + return NilKindString + } + return NilKindList +} + +// NilKind is the RLP value encoded in place of nil pointers. +type NilKind uint8 + +const ( + NilKindString NilKind = 0x80 + NilKindList NilKind = 0xC0 +) + +// Tags represents struct tags. +type Tags struct { + // rlp:"nil" controls whether empty input results in a nil pointer. + // nilKind is the kind of empty value allowed for the field. + NilKind NilKind + NilOK bool + + // rlp:"optional" allows for a field to be missing in the input list. + // If this is set, all subsequent fields must also be optional. + Optional bool + + // rlp:"tail" controls whether this field swallows additional list elements. It can + // only be set for the last field, which must be of slice type. + Tail bool + + // rlp:"-" ignores fields. + Ignored bool +} + +// TagError is raised for invalid struct tags. +type TagError struct { + StructType string + + // These are set by this package. + Field string + Tag string + Err string +} + +func (e TagError) Error() string { + field := "field " + e.Field + if e.StructType != "" { + field = e.StructType + "." + e.Field + } + return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err) +} + +// ProcessFields filters the given struct fields, returning only fields +// that should be considered for encoding/decoding. +func ProcessFields(allFields []Field) ([]Field, []Tags, error) { + lastPublic := lastPublicField(allFields) + + // Gather all exported fields and their tags. + var fields []Field + var tags []Tags + for _, field := range allFields { + if !field.Exported { + continue + } + ts, err := parseTag(field, lastPublic) + if err != nil { + return nil, nil, err + } + if ts.Ignored { + continue + } + fields = append(fields, field) + tags = append(tags, ts) + } + + // Verify optional field consistency. If any optional field exists, + // all fields after it must also be optional. Note: optional + tail + // is supported. + var anyOptional bool + var firstOptionalName string + for i, ts := range tags { + name := fields[i].Name + if ts.Optional || ts.Tail { + if !anyOptional { + firstOptionalName = name + } + anyOptional = true + } else { + if anyOptional { + msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName) + return nil, nil, TagError{Field: name, Err: msg} + } + } + } + return fields, tags, nil +} + +func parseTag(field Field, lastPublic int) (Tags, error) { + name := field.Name + tag := reflect.StructTag(field.Tag) + var ts Tags + for _, t := range strings.Split(tag.Get("rlp"), ",") { + switch t = strings.TrimSpace(t); t { + case "": + // empty tag is allowed for some reason + case "-": + ts.Ignored = true + case "nil", "nilString", "nilList": + ts.NilOK = true + if field.Type.Kind != reflect.Ptr { + return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"} + } + switch t { + case "nil": + ts.NilKind = field.Type.Elem.DefaultNilValue() + case "nilString": + ts.NilKind = NilKindString + case "nilList": + ts.NilKind = NilKindList + } + case "optional": + ts.Optional = true + if ts.Tail { + return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`} + } + case "tail": + ts.Tail = true + if field.Index != lastPublic { + return ts, TagError{Field: name, Tag: t, Err: "must be on last field"} + } + if ts.Optional { + return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`} + } + if field.Type.Kind != reflect.Slice { + return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"} + } + default: + return ts, TagError{Field: name, Tag: t, Err: "unknown tag"} + } + } + return ts, nil +} + +func lastPublicField(fields []Field) int { + last := 0 + for _, f := range fields { + if f.Exported { + last = f.Index + } + } + return last +} + +func isUint(k reflect.Kind) bool { + return k >= reflect.Uint && k <= reflect.Uintptr +} + +func isByte(typ Type) bool { + return typ.Kind == reflect.Uint8 && !typ.IsEncoder +} + +func isByteArray(typ Type) bool { + return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem) +} diff --git a/rlp/iterator.go b/rlp/iterator.go new file mode 100644 index 0000000000..6be574572e --- /dev/null +++ b/rlp/iterator.go @@ -0,0 +1,60 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +type listIterator struct { + data []byte + next []byte + err error +} + +// NewListIterator creates an iterator for the (list) represented by data +// TODO: Consider removing this implementation, as it is no longer used. +func NewListIterator(data RawValue) (*listIterator, error) { + k, t, c, err := readKind(data) + if err != nil { + return nil, err + } + if k != List { + return nil, ErrExpectedList + } + it := &listIterator{ + data: data[t : t+c], + } + return it, nil +} + +// Next forwards the iterator one step, returns true if it was not at end yet +func (it *listIterator) Next() bool { + if len(it.data) == 0 { + return false + } + _, t, c, err := readKind(it.data) + it.next = it.data[:t+c] + it.data = it.data[t+c:] + it.err = err + return true +} + +// Value returns the current value +func (it *listIterator) Value() []byte { + return it.next +} + +func (it *listIterator) Err() error { + return it.err +} diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go new file mode 100644 index 0000000000..87c11bdbae --- /dev/null +++ b/rlp/iterator_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "testing" + + "github.com/tomochain/tomochain/common/hexutil" +) + +// TestIterator tests some basic things about the ListIterator. A more +// comprehensive test can be found in core/rlp_test.go, where we can +// use both types and rlp without dependency cycles +func TestIterator(t *testing.T) { + bodyRlpHex := "0xf902cbf8d6f869800182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ba01025c66fad28b4ce3370222624d952c35529e602af7cbe04f667371f61b0e3b3a00ab8813514d1217059748fd903288ace1b4001a4bc5fbde2790debdc8167de2ff869010182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ca05ac4cf1d19be06f3742c21df6c49a7e929ceb3dbaf6a09f3cfb56ff6828bd9a7a06875970133a35e63ac06d360aa166d228cc013e9b96e0a2cae7f55b22e1ee2e8f901f0f901eda0c75448377c0e426b8017b23c5f77379ecf69abc1d5c224284ad3ba1c46c59adaa00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000808080808080a00000000000000000000000000000000000000000000000000000000000000000880000000000000000" + bodyRlp := hexutil.MustDecode(bodyRlpHex) + + it, err := NewListIterator(bodyRlp) + if err != nil { + t.Fatal(err) + } + // Check that txs exist + if !it.Next() { + t.Fatal("expected two elems, got zero") + } + txs := it.Value() + // Check that uncles exist + if !it.Next() { + t.Fatal("expected two elems, got one") + } + txit, err := NewListIterator(txs) + if err != nil { + t.Fatal(err) + } + var i = 0 + for txit.Next() { + if txit.err != nil { + t.Fatal(txit.err) + } + i++ + } + if exp := 2; i != exp { + t.Errorf("count wrong, expected %d got %d", i, exp) + } +} diff --git a/rlp/raw.go b/rlp/raw.go index 2b3f328f66..773aa7e614 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -28,12 +28,53 @@ type RawValue []byte var rawValueType = reflect.TypeOf(RawValue{}) +// StringSize returns the encoded size of a string. +func StringSize(s string) uint64 { + switch { + case len(s) == 0: + return 1 + case len(s) == 1: + if s[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(s))) + len(s)) + } +} + +// BytesSize returns the encoded size of a byte slice. +func BytesSize(b []byte) uint64 { + switch { + case len(b) == 0: + return 1 + case len(b) == 1: + if b[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(b))) + len(b)) + } +} + // ListSize returns the encoded size of an RLP list with the given // content size. func ListSize(contentSize uint64) uint64 { return uint64(headsize(contentSize)) + contentSize } +// IntSize returns the encoded size of the integer x. Note: The return type of this +// function is 'int' for backwards-compatibility reasons. The result is always positive. +func IntSize(x uint64) int { + if x < 0x80 { + return 1 + } + return 1 + intsize(x) +} + // Split returns the content of first RLP value and any // bytes after the value as subslices of b. func Split(b []byte) (k Kind, content, rest []byte, err error) { @@ -57,6 +98,32 @@ func SplitString(b []byte) (content, rest []byte, err error) { return content, rest, nil } +// SplitUint64 decodes an integer at the beginning of b. +// It also returns the remaining data after the integer in 'rest'. +func SplitUint64(b []byte) (x uint64, rest []byte, err error) { + content, rest, err := SplitString(b) + if err != nil { + return 0, b, err + } + switch { + case len(content) == 0: + return 0, rest, nil + case len(content) == 1: + if content[0] == 0 { + return 0, b, ErrCanonInt + } + return uint64(content[0]), rest, nil + case len(content) > 8: + return 0, b, errUintOverflow + default: + x, err = readSize(content, byte(len(content))) + if err != nil { + return 0, b, ErrCanonInt + } + return x, rest, nil + } +} + // SplitList splits b into the content of a list and any remaining // bytes after the list. func SplitList(b []byte) (content, rest []byte, err error) { @@ -154,3 +221,74 @@ func readSize(b []byte, slen byte) (uint64, error) { } return s, nil } + +// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice. +func AppendUint64(b []byte, i uint64) []byte { + if i == 0 { + return append(b, 0x80) + } else if i < 128 { + return append(b, byte(i)) + } + switch { + case i < (1 << 8): + return append(b, 0x81, byte(i)) + case i < (1 << 16): + return append(b, 0x82, + byte(i>>8), + byte(i), + ) + case i < (1 << 24): + return append(b, 0x83, + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 32): + return append(b, 0x84, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 40): + return append(b, 0x85, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + case i < (1 << 48): + return append(b, 0x86, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 56): + return append(b, 0x87, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + default: + return append(b, 0x88, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + } +} diff --git a/rlp/raw_test.go b/rlp/raw_test.go index 2aad042100..7b3255eca3 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -18,9 +18,10 @@ package rlp import ( "bytes" + "errors" "io" - "reflect" "testing" + "testing/quick" ) func TestCountValues(t *testing.T) { @@ -53,21 +54,84 @@ func TestCountValues(t *testing.T) { if count != test.count { t.Errorf("test %d: count mismatch, got %d want %d\ninput: %s", i, count, test.count, test.input) } - if !reflect.DeepEqual(err, test.err) { + if !errors.Is(err, test.err) { t.Errorf("test %d: err mismatch, got %q want %q\ninput: %s", i, err, test.err, test.input) } } } -func TestSplitTypes(t *testing.T) { - if _, _, err := SplitString(unhex("C100")); err != ErrExpectedString { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedString) +func TestSplitString(t *testing.T) { + for i, test := range []string{ + "C0", + "C100", + "C3010203", + "C88363617483646F67", + "F8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitString(unhex(test)); !errors.Is(err, ErrExpectedString) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedString) + } + } +} + +func TestSplitList(t *testing.T) { + for i, test := range []string{ + "80", + "00", + "01", + "8180", + "81FF", + "820400", + "83636174", + "83646F67", + "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitList(unhex(test)); !errors.Is(err, ErrExpectedList) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedList) + } } - if _, _, err := SplitList(unhex("01")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) +} + +func TestSplitUint64(t *testing.T) { + tests := []struct { + input string + val uint64 + rest string + err error + }{ + {"01", 1, "", nil}, + {"7FFF", 0x7F, "FF", nil}, + {"80FF", 0, "FF", nil}, + {"81FAFF", 0xFA, "FF", nil}, + {"82FAFAFF", 0xFAFA, "FF", nil}, + {"83FAFAFAFF", 0xFAFAFA, "FF", nil}, + {"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil}, + {"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil}, + {"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil}, + {"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil}, + {"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil}, + + // errors + {"", 0, "", io.ErrUnexpectedEOF}, + {"00", 0, "00", ErrCanonInt}, + {"81", 0, "81", ErrValueTooLarge}, + {"8100", 0, "8100", ErrCanonSize}, + {"8200FF", 0, "8200FF", ErrCanonInt}, + {"8103FF", 0, "8103FF", ErrCanonSize}, + {"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow}, } - if _, _, err := SplitList(unhex("81FF")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) + + for i, test := range tests { + val, rest, err := SplitUint64(unhex(test.input)) + if val != test.val { + t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input) + } + if !bytes.Equal(rest, unhex(test.rest)) { + t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input) + } + if err != test.err { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } } } @@ -78,7 +142,9 @@ func TestSplit(t *testing.T) { val, rest string err error }{ + {input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"}, {input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"}, + {input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"}, {input: "80FFFF", kind: String, val: "", rest: "FFFF"}, {input: "C3010203", kind: List, val: "010203"}, @@ -194,3 +260,79 @@ func TestReadSize(t *testing.T) { } } } + +func TestAppendUint64(t *testing.T) { + tests := []struct { + input uint64 + slice []byte + output string + }{ + {0, nil, "80"}, + {1, nil, "01"}, + {2, nil, "02"}, + {127, nil, "7F"}, + {128, nil, "8180"}, + {129, nil, "8181"}, + {0xFFFFFF, nil, "83FFFFFF"}, + {127, []byte{1, 2, 3}, "0102037F"}, + {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"}, + } + + for _, test := range tests { + x := AppendUint64(test.slice, test.input) + if !bytes.Equal(x, unhex(test.output)) { + t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output) + } + + // Check that IntSize returns the appended size. + length := len(x) - len(test.slice) + if s := IntSize(test.input); s != length { + t.Errorf("IntSize(%d): got %d, want %d", test.input, s, length) + } + } +} + +func TestAppendUint64Random(t *testing.T) { + fn := func(i uint64) bool { + enc, _ := EncodeToBytes(i) + encAppend := AppendUint64(nil, i) + return bytes.Equal(enc, encAppend) + } + config := quick.Config{MaxCountScale: 50} + if err := quick.Check(fn, &config); err != nil { + t.Fatal(err) + } +} + +func TestBytesSize(t *testing.T) { + tests := []struct { + v []byte + size uint64 + }{ + {v: []byte{}, size: 1}, + {v: []byte{0x1}, size: 1}, + {v: []byte{0x7E}, size: 1}, + {v: []byte{0x7F}, size: 1}, + {v: []byte{0x80}, size: 2}, + {v: []byte{0xFF}, size: 2}, + {v: []byte{0xFF, 0xF0}, size: 3}, + {v: make([]byte, 55), size: 56}, + {v: make([]byte, 56), size: 58}, + } + + for _, test := range tests { + s := BytesSize(test.v) + if s != test.size { + t.Errorf("BytesSize(%#x) -> %d, want %d", test.v, s, test.size) + } + s = StringSize(string(test.v)) + if s != test.size { + t.Errorf("StringSize(%#x) -> %d, want %d", test.v, s, test.size) + } + // Sanity check: + enc, _ := EncodeToBytes(test.v) + if uint64(len(enc)) != test.size { + t.Errorf("len(EncodeToBytes(%#x)) -> %d, test says %d", test.v, len(enc), test.size) + } + } +} diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go new file mode 100644 index 0000000000..26ccdc574e --- /dev/null +++ b/rlp/rlpgen/gen.go @@ -0,0 +1,800 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/types" + "sort" + + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" +) + +// buildContext keeps the data needed for make*Op. +type buildContext struct { + topType *types.Named // the type we're creating methods for + + encoderIface *types.Interface + decoderIface *types.Interface + rawValueType *types.Named + + typeToStructCache map[types.Type]*rlpstruct.Type +} + +func newBuildContext(packageRLP *types.Package) *buildContext { + enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying() + dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying() + rawv := packageRLP.Scope().Lookup("RawValue").Type() + return &buildContext{ + typeToStructCache: make(map[types.Type]*rlpstruct.Type), + encoderIface: enc.(*types.Interface), + decoderIface: dec.(*types.Interface), + rawValueType: rawv.(*types.Named), + } +} + +func (bctx *buildContext) isEncoder(typ types.Type) bool { + return types.Implements(typ, bctx.encoderIface) +} + +func (bctx *buildContext) isDecoder(typ types.Type) bool { + return types.Implements(typ, bctx.decoderIface) +} + +// typeToStructType converts typ to rlpstruct.Type. +func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { + if prev := bctx.typeToStructCache[typ]; prev != nil { + return prev // short-circuit for recursive types. + } + + // Resolve named types to their underlying type, but keep the name. + name := types.TypeString(typ, nil) + for { + utype := typ.Underlying() + if utype == typ { + break + } + typ = utype + } + + // Create the type and store it in cache. + t := &rlpstruct.Type{ + Name: name, + Kind: typeReflectKind(typ), + IsEncoder: bctx.isEncoder(typ), + IsDecoder: bctx.isDecoder(typ), + } + bctx.typeToStructCache[typ] = t + + // Assign element type. + switch typ.(type) { + case *types.Array, *types.Slice, *types.Pointer: + etype := typ.(interface{ Elem() types.Type }).Elem() + t.Elem = bctx.typeToStructType(etype) + } + return t +} + +// genContext is passed to the gen* methods of op when generating +// the output code. It tracks packages to be imported by the output +// file and assigns unique names of temporary variables. +type genContext struct { + inPackage *types.Package + imports map[string]struct{} + tempCounter int +} + +func newGenContext(inPackage *types.Package) *genContext { + return &genContext{ + inPackage: inPackage, + imports: make(map[string]struct{}), + } +} + +func (ctx *genContext) temp() string { + v := fmt.Sprintf("_tmp%d", ctx.tempCounter) + ctx.tempCounter++ + return v +} + +func (ctx *genContext) resetTemp() { + ctx.tempCounter = 0 +} + +func (ctx *genContext) addImport(path string) { + if path == ctx.inPackage.Path() { + return // avoid importing the package that we're generating in. + } + // TODO: renaming? + ctx.imports[path] = struct{}{} +} + +// importsList returns all packages that need to be imported. +func (ctx *genContext) importsList() []string { + imp := make([]string, 0, len(ctx.imports)) + for k := range ctx.imports { + imp = append(imp, k) + } + sort.Strings(imp) + return imp +} + +// qualify is the types.Qualifier used for printing types. +func (ctx *genContext) qualify(pkg *types.Package) string { + if pkg.Path() == ctx.inPackage.Path() { + return "" + } + ctx.addImport(pkg.Path()) + // TODO: renaming? + return pkg.Name() +} + +type op interface { + // genWrite creates the encoder. The generated code should write v, + // which is any Go expression, to the rlp.EncoderBuffer 'w'. + genWrite(ctx *genContext, v string) string + + // genDecode creates the decoder. The generated code should read + // a value from the rlp.Stream 'dec' and store it to dst. + genDecode(ctx *genContext) (string, string) +} + +// basicOp handles basic types bool, uint*, string. +type basicOp struct { + typ types.Type + writeMethod string // calle write the value + writeArgType types.Type // parameter type of writeMethod + decMethod string + decResultType types.Type // return type of decMethod + decUseBitSize bool // if true, result bit size is appended to decMethod +} + +func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) { + op := basicOp{typ: typ} + kind := typ.Kind() + switch { + case kind == types.Bool: + op.writeMethod = "WriteBool" + op.writeArgType = types.Typ[types.Bool] + op.decMethod = "Bool" + op.decResultType = types.Typ[types.Bool] + case kind >= types.Uint8 && kind <= types.Uint64: + op.writeMethod = "WriteUint64" + op.writeArgType = types.Typ[types.Uint64] + op.decMethod = "Uint" + op.decResultType = typ + op.decUseBitSize = true + case kind == types.String: + op.writeMethod = "WriteString" + op.writeArgType = types.Typ[types.String] + op.decMethod = "String" + op.decResultType = types.Typ[types.String] + default: + return nil, fmt.Errorf("unhandled basic type: %v", typ) + } + return op, nil +} + +func (*buildContext) makeByteSliceOp(typ *types.Slice) op { + if !isByte(typ.Elem()) { + panic("non-byte slice type in makeByteSliceOp") + } + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: typ, + writeMethod: "WriteBytes", + writeArgType: bslice, + decMethod: "Bytes", + decResultType: bslice, + } +} + +func (bctx *buildContext) makeRawValueOp() op { + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: bctx.rawValueType, + writeMethod: "Write", + writeArgType: bslice, + decMethod: "Raw", + decResultType: bslice, + } +} + +func (op basicOp) writeNeedsConversion() bool { + return !types.AssignableTo(op.typ, op.writeArgType) +} + +func (op basicOp) decodeNeedsConversion() bool { + return !types.AssignableTo(op.decResultType, op.typ) +} + +func (op basicOp) genWrite(ctx *genContext, v string) string { + if op.writeNeedsConversion() { + v = fmt.Sprintf("%s(%s)", op.writeArgType, v) + } + return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v) +} + +func (op basicOp) genDecode(ctx *genContext) (string, string) { + var ( + resultV = ctx.temp() + result = resultV + method = op.decMethod + ) + if op.decUseBitSize { + // Note: For now, this only works for platform-independent integer + // sizes. makeBasicOp forbids the platform-dependent types. + var sizes types.StdSizes + method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8) + } + + // Call the decoder method. + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method) + fmt.Fprintf(&b, "if err != nil { return err }\n") + if op.decodeNeedsConversion() { + conv := ctx.temp() + fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV) + result = conv + } + return result, b.String() +} + +// byteArrayOp handles [...]byte. +type byteArrayOp struct { + typ types.Type + name types.Type // name != typ for named byte array types (e.g. common.Address) +} + +func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp { + nt := types.Type(name) + if name == nil { + nt = typ + } + return byteArrayOp{typ, nt} +} + +func (op byteArrayOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("w.WriteBytes(%s[:])\n", v) +} + +func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify)) + fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// bigIntOp handles big.Int. +// This exists because big.Int has it's own decoder operation on rlp.Stream, +// but the decode method returns *big.Int, so it needs to be dereferenced. +type bigIntOp struct { + pointer bool +} + +func (op bigIntOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v) + fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n") + fmt.Fprintf(&b, "}\n") + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op bigIntOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV) + fmt.Fprintf(&b, "if err != nil { return err }\n") + + result := resultV + if !op.pointer { + result = "(*" + resultV + ")" + } + return result, b.String() +} + +// uint256Op handles "github.com/holiman/uint256".Int +type uint256Op struct { + pointer bool +} + +func (op uint256Op) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op uint256Op) genDecode(ctx *genContext) (string, string) { + ctx.addImport("github.com/holiman/uint256") + + var b bytes.Buffer + resultV := ctx.temp() + fmt.Fprintf(&b, "var %s uint256.Int\n", resultV) + fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV) + + result := resultV + if op.pointer { + result = "&" + resultV + } + return result, b.String() +} + +// encoderDecoderOp handles rlp.Encoder and rlp.Decoder. +// In order to be used with this, the type must implement both interfaces. +// This restriction may be lifted in the future by creating separate ops for +// encoding and decoding. +type encoderDecoderOp struct { + typ types.Type +} + +func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v) +} + +func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) { + // DecodeRLP must have pointer receiver, and this is verified in makeOp. + etyp := op.typ.(*types.Pointer).Elem() + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify)) + fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// ptrOp handles pointer types. +type ptrOp struct { + elemTyp types.Type + elem op + nilOK bool + nilValue rlpstruct.NilKind +} + +func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) { + elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + op := ptrOp{elemTyp: elemTyp, elem: elemOp} + + // Determine nil value. + if tags.NilOK { + op.nilOK = true + op.nilValue = tags.NilKind + } else { + styp := bctx.typeToStructType(elemTyp) + op.nilValue = styp.DefaultNilValue() + } + return op, nil +} + +func (op ptrOp) genWrite(ctx *genContext, v string) string { + // Note: in writer functions, accesses to v are read-only, i.e. v is any Go + // expression. To make all accesses work through the pointer, we substitute + // v with (*v). This is required for most accesses including `v`, `call(v)`, + // and `v[index]` on slices. + // + // For `v.field` and `v[:]` on arrays, the dereference operation is not required. + var vv string + _, isStruct := op.elem.(structOp) + _, isByteArray := op.elem.(byteArrayOp) + if isStruct || isByteArray { + vv = v + } else { + vv = fmt.Sprintf("(*%s)", v) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue) + fmt.Fprintf(&b, "} else {\n") + fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv)) + fmt.Fprintf(&b, "}\n") + return b.String() +} + +func (op ptrOp) genDecode(ctx *genContext) (string, string) { + result, code := op.elem.genDecode(ctx) + if !op.nilOK { + // If nil pointers are not allowed, we can just decode the element. + return "&" + result, code + } + + // nil is allowed, so check the kind and size first. + // If size is zero and kind matches the nilKind of the type, + // the value decodes as a nil pointer. + var ( + resultV = ctx.temp() + kindV = ctx.temp() + sizeV = ctx.temp() + wantKind string + ) + if op.nilValue == rlpstruct.NilKindList { + wantKind = "rlp.List" + } else { + wantKind = "rlp.String" + } + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify)) + fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV) + fmt.Fprintf(&b, " return err\n") + fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " %s = &%s\n", resultV, result) + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +// structOp handles struct types. +type structOp struct { + named *types.Named + typ *types.Struct + fields []*structField + optionalFields []*structField +} + +type structField struct { + name string + typ types.Type + elem op +} + +func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) { + // Convert fields to []rlpstruct.Field. + var allStructFields []rlpstruct.Field + for i := 0; i < typ.NumFields(); i++ { + f := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: f.Name(), + Exported: f.Exported(), + Index: i, + Tag: typ.Tag(i), + Type: *bctx.typeToStructType(f.Type()), + }) + } + + // Filter/validate fields. + fields, tags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + return nil, err + } + + // Create field ops. + var op = structOp{named: named, typ: typ} + for i, field := range fields { + // Advanced struct tags are not supported yet. + tag := tags[i] + if err := checkUnsupportedTags(field.Name, tag); err != nil { + return nil, err + } + typ := typ.Field(field.Index).Type() + elem, err := bctx.makeOp(nil, typ, tags[i]) + if err != nil { + return nil, fmt.Errorf("field %s: %v", field.Name, err) + } + f := &structField{name: field.Name, typ: typ, elem: elem} + if tag.Optional { + op.optionalFields = append(op.optionalFields, f) + } else { + op.fields = append(op.fields, f) + } + } + return op, nil +} + +func checkUnsupportedTags(field string, tag rlpstruct.Tags) error { + if tag.Tail { + return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field) + } + return nil +} + +func (op structOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + var listMarker = ctx.temp() + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + for _, field := range op.fields { + selector := v + "." + field.name + fmt.Fprint(&b, field.elem.genWrite(ctx, selector)) + } + op.writeOptionalFields(&b, ctx, v) + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) { + if len(op.optionalFields) == 0 { + return + } + // First check zero-ness of all optional fields. + var zeroV = make([]string, len(op.optionalFields)) + for i, field := range op.optionalFields { + selector := v + "." + field.name + zeroV[i] = ctx.temp() + fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify)) + } + // Now write the fields. + for i, field := range op.optionalFields { + selector := v + "." + field.name + cond := "" + for j := i; j < len(op.optionalFields); j++ { + if j > i { + cond += " || " + } + cond += zeroV[j] + } + fmt.Fprintf(b, "if %s {\n", cond) + fmt.Fprint(b, field.elem.genWrite(ctx, selector)) + fmt.Fprintf(b, "}\n") + } +} + +func (op structOp) genDecode(ctx *genContext) (string, string) { + // Get the string representation of the type. + // Here, named types are handled separately because the output + // would contain a copy of the struct definition otherwise. + var typeName string + if op.named != nil { + typeName = types.TypeString(op.named, ctx.qualify) + } else { + typeName = types.TypeString(op.typ, ctx.qualify) + } + + // Create struct object. + var resultV = ctx.temp() + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, typeName) + + // Decode fields. + fmt.Fprintf(&b, "{\n") + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + for _, field := range op.fields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(&b, "// %s:\n", field.name) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result) + } + op.decodeOptionalFields(&b, ctx, resultV) + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) { + var suffix bytes.Buffer + for _, field := range op.optionalFields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(b, "// %s:\n", field.name) + fmt.Fprintf(b, "if dec.MoreDataInList() {\n") + fmt.Fprint(b, code) + fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result) + fmt.Fprintf(&suffix, "}\n") + } + suffix.WriteTo(b) +} + +// sliceOp handles slice types. +type sliceOp struct { + typ *types.Slice + elemOp op +} + +func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) { + elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{}) + if err != nil { + return nil, err + } + return sliceOp{typ: typ, elemOp: elemOp}, nil +} + +func (op sliceOp) genWrite(ctx *genContext, v string) string { + var ( + listMarker = ctx.temp() // holds return value of w.List() + iterElemV = ctx.temp() // iteration variable + elemCode = op.elemOp.genWrite(ctx, iterElemV) + ) + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v) + fmt.Fprint(&b, elemCode) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op sliceOp) genDecode(ctx *genContext) (string, string) { + var sliceV = ctx.temp() // holds the output slice + elemResult, elemCode := op.elemOp.genDecode(ctx) + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify)) + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + fmt.Fprintf(&b, "for dec.MoreDataInList() {\n") + fmt.Fprintf(&b, " %s", elemCode) + fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + return sliceV, b.String() +} + +func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) { + switch typ := typ.(type) { + case *types.Named: + if isBigInt(typ) { + return bigIntOp{}, nil + } + if isUint256(typ) { + return uint256Op{}, nil + } + if typ == bctx.rawValueType { + return bctx.makeRawValueOp(), nil + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ) + } + // TODO: same check for encoder? + return bctx.makeOp(typ, typ.Underlying(), tags) + case *types.Pointer: + if isBigInt(typ.Elem()) { + return bigIntOp{pointer: true}, nil + } + if isUint256(typ.Elem()) { + return uint256Op{pointer: true}, nil + } + // Encoder/Decoder interfaces. + if bctx.isEncoder(typ) { + if bctx.isDecoder(typ) { + return encoderDecoderOp{typ}, nil + } + return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ) + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ) + } + // Default pointer handling. + return bctx.makePtrOp(typ.Elem(), tags) + case *types.Basic: + return bctx.makeBasicOp(typ) + case *types.Struct: + return bctx.makeStructOp(name, typ) + case *types.Slice: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteSliceOp(typ), nil + } + return bctx.makeSliceOp(typ) + case *types.Array: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteArrayOp(name, typ), nil + } + return nil, fmt.Errorf("unhandled array type: %v", typ) + default: + return nil, fmt.Errorf("unhandled type: %v", typ) + } +} + +// generateDecoder generates the DecodeRLP method on 'typ'. +func generateDecoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport(pathOfPackageRLP) + + result, code := op.genDecode(ctx) + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " *obj = %s\n", result) + fmt.Fprintf(&b, " return nil\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +// generateEncoder generates the EncodeRLP method on 'typ'. +func generateEncoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport("io") + ctx.addImport(pathOfPackageRLP) + + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) + fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n") + fmt.Fprint(&b, op.genWrite(ctx, "obj")) + fmt.Fprintf(&b, " return w.Flush()\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) { + bctx.topType = typ + + pkg := typ.Obj().Pkg() + op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + + var ( + ctx = newGenContext(pkg) + encSource []byte + decSource []byte + ) + if encoder { + encSource = generateEncoder(ctx, typ.Obj().Name(), op) + } + if decoder { + decSource = generateDecoder(ctx, typ.Obj().Name(), op) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) + for _, imp := range ctx.importsList() { + fmt.Fprintf(&b, "import %q\n", imp) + } + if encoder { + fmt.Fprintln(&b) + b.Write(encSource) + } + if decoder { + fmt.Fprintln(&b) + b.Write(decSource) + } + + source := b.Bytes() + // fmt.Println(string(source)) + return format.Source(source) +} diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go new file mode 100644 index 0000000000..3b4f5df287 --- /dev/null +++ b/rlp/rlpgen/gen_test.go @@ -0,0 +1,107 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "os" + "path/filepath" + "testing" +) + +// Package RLP is loaded only once and reused for all tests. +var ( + testFset = token.NewFileSet() + testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom) + testPackageRLP *types.Package +) + +func init() { + cwd, err := os.Getwd() + if err != nil { + panic(err) + } + testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0) + if err != nil { + panic(fmt.Errorf("can't load package RLP: %v", err)) + } +} + +var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} + +func TestOutput(t *testing.T) { + for _, test := range tests { + test := test + t.Run(test, func(t *testing.T) { + inputFile := filepath.Join("testdata", test+".in.txt") + outputFile := filepath.Join("testdata", test+".out.txt") + bctx, typ, err := loadTestSource(inputFile, "Test") + if err != nil { + t.Fatal("error loading test source:", err) + } + output, err := bctx.generate(typ, true, true) + if err != nil { + t.Fatal("error in generate:", err) + } + + // Set this environment variable to regenerate the test outputs. + if os.Getenv("WRITE_TEST_FILES") != "" { + os.WriteFile(outputFile, output, 0644) + } + + // Check if output matches. + wantOutput, err := os.ReadFile(outputFile) + if err != nil { + t.Fatal("error loading expected test output:", err) + } + if !bytes.Equal(output, wantOutput) { + t.Fatalf("output mismatch, want: %v got %v", string(wantOutput), string(output)) + } + }) + } +} + +func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) { + // Load the test input. + content, err := os.ReadFile(file) + if err != nil { + return nil, nil, err + } + f, err := parser.ParseFile(testFset, file, content, 0) + if err != nil { + return nil, nil, err + } + conf := types.Config{Importer: testImporter} + pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil) + if err != nil { + return nil, nil, err + } + + // Find the test struct. + bctx := newBuildContext(testPackageRLP) + typ, err := lookupStructType(pkg.Scope(), typeName) + if err != nil { + return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err) + } + return bctx, typ, nil +} diff --git a/rlp/rlpgen/main.go b/rlp/rlpgen/main.go new file mode 100644 index 0000000000..87aebbc47a --- /dev/null +++ b/rlp/rlpgen/main.go @@ -0,0 +1,147 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/types" + "os" + + "golang.org/x/tools/go/packages" +) + +const pathOfPackageRLP = "github.com/tomochain/tomochain/rlp" + +func main() { + var ( + pkgdir = flag.String("dir", ".", "input package") + output = flag.String("out", "-", "output file (default is stdout)") + genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?") + genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?") + typename = flag.String("type", "", "type to generate methods for") + ) + flag.Parse() + + cfg := Config{ + Dir: *pkgdir, + Type: *typename, + GenerateEncoder: *genEncoder, + GenerateDecoder: *genDecoder, + } + code, err := cfg.process() + if err != nil { + fatal(err) + } + if *output == "-" { + os.Stdout.Write(code) + } else if err := os.WriteFile(*output, code, 0600); err != nil { + fatal(err) + } +} + +func fatal(args ...interface{}) { + fmt.Fprintln(os.Stderr, args...) + os.Exit(1) +} + +type Config struct { + Dir string // input package directory + Type string + + GenerateEncoder bool + GenerateDecoder bool +} + +// process generates the Go code. +func (cfg *Config) process() (code []byte, err error) { + // Load packages. + pcfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedImports | packages.NeedDeps, + Dir: cfg.Dir, + BuildFlags: []string{"-tags", "norlpgen"}, + } + ps, err := packages.Load(pcfg, pathOfPackageRLP, ".") + if err != nil { + return nil, err + } + if len(ps) == 0 { + return nil, fmt.Errorf("no Go package found in %s", cfg.Dir) + } + packages.PrintErrors(ps) + + // Find the packages that were loaded. + var ( + pkg *types.Package + packageRLP *types.Package + ) + for _, p := range ps { + if len(p.Errors) > 0 { + return nil, fmt.Errorf("package %s has errors", p.PkgPath) + } + if p.PkgPath == pathOfPackageRLP { + packageRLP = p.Types + } else { + pkg = p.Types + } + } + bctx := newBuildContext(packageRLP) + + // Find the type and generate. + typ, err := lookupStructType(pkg.Scope(), cfg.Type) + if err != nil { + return nil, fmt.Errorf("can't find %s in %s: %v", cfg.Type, pkg, err) + } + code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder) + if err != nil { + return nil, err + } + + // Add build comments. + // This is done here to avoid processing these lines with gofmt. + var header bytes.Buffer + fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n") + fmt.Fprint(&header, "//go:build !norlpgen\n") + fmt.Fprint(&header, "// +build !norlpgen\n\n") + return append(header.Bytes(), code...), nil +} + +func lookupStructType(scope *types.Scope, name string) (*types.Named, error) { + typ, err := lookupType(scope, name) + if err != nil { + return nil, err + } + _, ok := typ.Underlying().(*types.Struct) + if !ok { + return nil, errors.New("not a struct type") + } + return typ, nil +} + +func lookupType(scope *types.Scope, name string) (*types.Named, error) { + obj := scope.Lookup(name) + if obj == nil { + return nil, errors.New("no such identifier") + } + typ, ok := obj.(*types.TypeName) + if !ok { + return nil, errors.New("not a type") + } + return typ.Type().(*types.Named), nil +} diff --git a/rlp/rlpgen/testdata/bigint.in.txt b/rlp/rlpgen/testdata/bigint.in.txt new file mode 100644 index 0000000000..d23d84a287 --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "math/big" + +type Test struct { + Int *big.Int + IntNoPtr big.Int +} diff --git a/rlp/rlpgen/testdata/bigint.out.txt b/rlp/rlpgen/testdata/bigint.out.txt new file mode 100644 index 0000000000..6dc7bea3bf --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.out.txt @@ -0,0 +1,49 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Int.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Int) + } + if obj.IntNoPtr.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + _tmp1, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.Int = _tmp1 + // IntNoPtr: + _tmp2, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.IntNoPtr = (*_tmp2) + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/nil.in.txt b/rlp/rlpgen/testdata/nil.in.txt new file mode 100644 index 0000000000..a28ff34487 --- /dev/null +++ b/rlp/rlpgen/testdata/nil.in.txt @@ -0,0 +1,30 @@ +// -*- mode: go -*- + +package test + +type Aux struct{ + A uint32 +} + +type Test struct{ + Uint8 *byte `rlp:"nil"` + Uint8List *byte `rlp:"nilList"` + + Uint32 *uint32 `rlp:"nil"` + Uint32List *uint32 `rlp:"nilList"` + + Uint64 *uint64 `rlp:"nil"` + Uint64List *uint64 `rlp:"nilList"` + + String *string `rlp:"nil"` + StringList *string `rlp:"nilList"` + + ByteArray *[3]byte `rlp:"nil"` + ByteArrayList *[3]byte `rlp:"nilList"` + + ByteSlice *[]byte `rlp:"nil"` + ByteSliceList *[]byte `rlp:"nilList"` + + Struct *Aux `rlp:"nil"` + StructString *Aux `rlp:"nilString"` +} diff --git a/rlp/rlpgen/testdata/nil.out.txt b/rlp/rlpgen/testdata/nil.out.txt new file mode 100644 index 0000000000..b3bdd0b86f --- /dev/null +++ b/rlp/rlpgen/testdata/nil.out.txt @@ -0,0 +1,289 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Uint8 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint8))) + } + if obj.Uint8List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint8List))) + } + if obj.Uint32 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint32))) + } + if obj.Uint32List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint32List))) + } + if obj.Uint64 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Uint64)) + } + if obj.Uint64List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64((*obj.Uint64List)) + } + if obj.String == nil { + w.Write([]byte{0x80}) + } else { + w.WriteString((*obj.String)) + } + if obj.StringList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteString((*obj.StringList)) + } + if obj.ByteArray == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes(obj.ByteArray[:]) + } + if obj.ByteArrayList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes(obj.ByteArrayList[:]) + } + if obj.ByteSlice == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes((*obj.ByteSlice)) + } + if obj.ByteSliceList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes((*obj.ByteSliceList)) + } + if obj.Struct == nil { + w.Write([]byte{0xC0}) + } else { + _tmp1 := w.List() + w.WriteUint64(uint64(obj.Struct.A)) + w.ListEnd(_tmp1) + } + if obj.StructString == nil { + w.Write([]byte{0x80}) + } else { + _tmp2 := w.List() + w.WriteUint64(uint64(obj.StructString.A)) + w.ListEnd(_tmp2) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint8: + var _tmp2 *byte + if _tmp3, _tmp4, err := dec.Kind(); err != nil { + return err + } else if _tmp4 != 0 || _tmp3 != rlp.String { + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp2 = &_tmp1 + } + _tmp0.Uint8 = _tmp2 + // Uint8List: + var _tmp6 *byte + if _tmp7, _tmp8, err := dec.Kind(); err != nil { + return err + } else if _tmp8 != 0 || _tmp7 != rlp.List { + _tmp5, err := dec.Uint8() + if err != nil { + return err + } + _tmp6 = &_tmp5 + } + _tmp0.Uint8List = _tmp6 + // Uint32: + var _tmp10 *uint32 + if _tmp11, _tmp12, err := dec.Kind(); err != nil { + return err + } else if _tmp12 != 0 || _tmp11 != rlp.String { + _tmp9, err := dec.Uint32() + if err != nil { + return err + } + _tmp10 = &_tmp9 + } + _tmp0.Uint32 = _tmp10 + // Uint32List: + var _tmp14 *uint32 + if _tmp15, _tmp16, err := dec.Kind(); err != nil { + return err + } else if _tmp16 != 0 || _tmp15 != rlp.List { + _tmp13, err := dec.Uint32() + if err != nil { + return err + } + _tmp14 = &_tmp13 + } + _tmp0.Uint32List = _tmp14 + // Uint64: + var _tmp18 *uint64 + if _tmp19, _tmp20, err := dec.Kind(); err != nil { + return err + } else if _tmp20 != 0 || _tmp19 != rlp.String { + _tmp17, err := dec.Uint64() + if err != nil { + return err + } + _tmp18 = &_tmp17 + } + _tmp0.Uint64 = _tmp18 + // Uint64List: + var _tmp22 *uint64 + if _tmp23, _tmp24, err := dec.Kind(); err != nil { + return err + } else if _tmp24 != 0 || _tmp23 != rlp.List { + _tmp21, err := dec.Uint64() + if err != nil { + return err + } + _tmp22 = &_tmp21 + } + _tmp0.Uint64List = _tmp22 + // String: + var _tmp26 *string + if _tmp27, _tmp28, err := dec.Kind(); err != nil { + return err + } else if _tmp28 != 0 || _tmp27 != rlp.String { + _tmp25, err := dec.String() + if err != nil { + return err + } + _tmp26 = &_tmp25 + } + _tmp0.String = _tmp26 + // StringList: + var _tmp30 *string + if _tmp31, _tmp32, err := dec.Kind(); err != nil { + return err + } else if _tmp32 != 0 || _tmp31 != rlp.List { + _tmp29, err := dec.String() + if err != nil { + return err + } + _tmp30 = &_tmp29 + } + _tmp0.StringList = _tmp30 + // ByteArray: + var _tmp34 *[3]byte + if _tmp35, _tmp36, err := dec.Kind(); err != nil { + return err + } else if _tmp36 != 0 || _tmp35 != rlp.String { + var _tmp33 [3]byte + if err := dec.ReadBytes(_tmp33[:]); err != nil { + return err + } + _tmp34 = &_tmp33 + } + _tmp0.ByteArray = _tmp34 + // ByteArrayList: + var _tmp38 *[3]byte + if _tmp39, _tmp40, err := dec.Kind(); err != nil { + return err + } else if _tmp40 != 0 || _tmp39 != rlp.List { + var _tmp37 [3]byte + if err := dec.ReadBytes(_tmp37[:]); err != nil { + return err + } + _tmp38 = &_tmp37 + } + _tmp0.ByteArrayList = _tmp38 + // ByteSlice: + var _tmp42 *[]byte + if _tmp43, _tmp44, err := dec.Kind(); err != nil { + return err + } else if _tmp44 != 0 || _tmp43 != rlp.String { + _tmp41, err := dec.Bytes() + if err != nil { + return err + } + _tmp42 = &_tmp41 + } + _tmp0.ByteSlice = _tmp42 + // ByteSliceList: + var _tmp46 *[]byte + if _tmp47, _tmp48, err := dec.Kind(); err != nil { + return err + } else if _tmp48 != 0 || _tmp47 != rlp.List { + _tmp45, err := dec.Bytes() + if err != nil { + return err + } + _tmp46 = &_tmp45 + } + _tmp0.ByteSliceList = _tmp46 + // Struct: + var _tmp51 *Aux + if _tmp52, _tmp53, err := dec.Kind(); err != nil { + return err + } else if _tmp53 != 0 || _tmp52 != rlp.List { + var _tmp49 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp50, err := dec.Uint32() + if err != nil { + return err + } + _tmp49.A = _tmp50 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp51 = &_tmp49 + } + _tmp0.Struct = _tmp51 + // StructString: + var _tmp56 *Aux + if _tmp57, _tmp58, err := dec.Kind(); err != nil { + return err + } else if _tmp58 != 0 || _tmp57 != rlp.String { + var _tmp54 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp55, err := dec.Uint32() + if err != nil { + return err + } + _tmp54.A = _tmp55 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp56 = &_tmp54 + } + _tmp0.StructString = _tmp56 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/optional.in.txt b/rlp/rlpgen/testdata/optional.in.txt new file mode 100644 index 0000000000..f1ac9f7899 --- /dev/null +++ b/rlp/rlpgen/testdata/optional.in.txt @@ -0,0 +1,17 @@ +// -*- mode: go -*- + +package test + +type Aux struct { + A uint64 +} + +type Test struct { + Uint64 uint64 `rlp:"optional"` + Pointer *uint64 `rlp:"optional"` + String string `rlp:"optional"` + Slice []uint64 `rlp:"optional"` + Array [3]byte `rlp:"optional"` + NamedStruct Aux `rlp:"optional"` + AnonStruct struct{ A string } `rlp:"optional"` +} diff --git a/rlp/rlpgen/testdata/optional.out.txt b/rlp/rlpgen/testdata/optional.out.txt new file mode 100644 index 0000000000..fb9b95d44d --- /dev/null +++ b/rlp/rlpgen/testdata/optional.out.txt @@ -0,0 +1,153 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + _tmp1 := obj.Uint64 != 0 + _tmp2 := obj.Pointer != nil + _tmp3 := obj.String != "" + _tmp4 := len(obj.Slice) > 0 + _tmp5 := obj.Array != ([3]byte{}) + _tmp6 := obj.NamedStruct != (Aux{}) + _tmp7 := obj.AnonStruct != (struct{ A string }{}) + if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteUint64(obj.Uint64) + } + if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + if obj.Pointer == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Pointer)) + } + } + if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteString(obj.String) + } + if _tmp4 || _tmp5 || _tmp6 || _tmp7 { + _tmp8 := w.List() + for _, _tmp9 := range obj.Slice { + w.WriteUint64(_tmp9) + } + w.ListEnd(_tmp8) + } + if _tmp5 || _tmp6 || _tmp7 { + w.WriteBytes(obj.Array[:]) + } + if _tmp6 || _tmp7 { + _tmp10 := w.List() + w.WriteUint64(obj.NamedStruct.A) + w.ListEnd(_tmp10) + } + if _tmp7 { + _tmp11 := w.List() + w.WriteString(obj.AnonStruct.A) + w.ListEnd(_tmp11) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint64: + if dec.MoreDataInList() { + _tmp1, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Uint64 = _tmp1 + // Pointer: + if dec.MoreDataInList() { + _tmp2, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Pointer = &_tmp2 + // String: + if dec.MoreDataInList() { + _tmp3, err := dec.String() + if err != nil { + return err + } + _tmp0.String = _tmp3 + // Slice: + if dec.MoreDataInList() { + var _tmp4 []uint64 + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp5, err := dec.Uint64() + if err != nil { + return err + } + _tmp4 = append(_tmp4, _tmp5) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.Slice = _tmp4 + // Array: + if dec.MoreDataInList() { + var _tmp6 [3]byte + if err := dec.ReadBytes(_tmp6[:]); err != nil { + return err + } + _tmp0.Array = _tmp6 + // NamedStruct: + if dec.MoreDataInList() { + var _tmp7 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp8, err := dec.Uint64() + if err != nil { + return err + } + _tmp7.A = _tmp8 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.NamedStruct = _tmp7 + // AnonStruct: + if dec.MoreDataInList() { + var _tmp9 struct{ A string } + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp10, err := dec.String() + if err != nil { + return err + } + _tmp9.A = _tmp10 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.AnonStruct = _tmp9 + } + } + } + } + } + } + } + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/rawvalue.in.txt b/rlp/rlpgen/testdata/rawvalue.in.txt new file mode 100644 index 0000000000..6c17849954 --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.in.txt @@ -0,0 +1,11 @@ +// -*- mode: go -*- + +package test + +import "github.com/tomochain/tomochain/rlp" + +type Test struct { + RawValue rlp.RawValue + PointerToRawValue *rlp.RawValue + SliceOfRawValue []rlp.RawValue +} diff --git a/rlp/rlpgen/testdata/rawvalue.out.txt b/rlp/rlpgen/testdata/rawvalue.out.txt new file mode 100644 index 0000000000..4b6eb385d6 --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.out.txt @@ -0,0 +1,64 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.Write(obj.RawValue) + if obj.PointerToRawValue == nil { + w.Write([]byte{0x80}) + } else { + w.Write((*obj.PointerToRawValue)) + } + _tmp1 := w.List() + for _, _tmp2 := range obj.SliceOfRawValue { + w.Write(_tmp2) + } + w.ListEnd(_tmp1) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // RawValue: + _tmp1, err := dec.Raw() + if err != nil { + return err + } + _tmp0.RawValue = _tmp1 + // PointerToRawValue: + _tmp2, err := dec.Raw() + if err != nil { + return err + } + _tmp0.PointerToRawValue = &_tmp2 + // SliceOfRawValue: + var _tmp3 []rlp.RawValue + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp4, err := dec.Raw() + if err != nil { + return err + } + _tmp3 = append(_tmp3, _tmp4) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.SliceOfRawValue = _tmp3 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uint256.in.txt b/rlp/rlpgen/testdata/uint256.in.txt new file mode 100644 index 0000000000..ed16e0a788 --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "github.com/holiman/uint256" + +type Test struct { + Int *uint256.Int + IntNoPtr uint256.Int +} diff --git a/rlp/rlpgen/testdata/uint256.out.txt b/rlp/rlpgen/testdata/uint256.out.txt new file mode 100644 index 0000000000..5d99ca2e6d --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.out.txt @@ -0,0 +1,44 @@ +package test + +import "github.com/holiman/uint256" +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteUint256(obj.Int) + } + w.WriteUint256(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + var _tmp1 uint256.Int + if err := dec.ReadUint256(&_tmp1); err != nil { + return err + } + _tmp0.Int = &_tmp1 + // IntNoPtr: + var _tmp2 uint256.Int + if err := dec.ReadUint256(&_tmp2); err != nil { + return err + } + _tmp0.IntNoPtr = _tmp2 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uints.in.txt b/rlp/rlpgen/testdata/uints.in.txt new file mode 100644 index 0000000000..8095da997d --- /dev/null +++ b/rlp/rlpgen/testdata/uints.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +type Test struct{ + A uint8 + B uint16 + C uint32 + D uint64 +} diff --git a/rlp/rlpgen/testdata/uints.out.txt b/rlp/rlpgen/testdata/uints.out.txt new file mode 100644 index 0000000000..17896dd305 --- /dev/null +++ b/rlp/rlpgen/testdata/uints.out.txt @@ -0,0 +1,53 @@ +package test + +import "github.com/tomochain/tomochain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteUint64(uint64(obj.A)) + w.WriteUint64(uint64(obj.B)) + w.WriteUint64(uint64(obj.C)) + w.WriteUint64(obj.D) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp0.A = _tmp1 + // B: + _tmp2, err := dec.Uint16() + if err != nil { + return err + } + _tmp0.B = _tmp2 + // C: + _tmp3, err := dec.Uint32() + if err != nil { + return err + } + _tmp0.C = _tmp3 + // D: + _tmp4, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.D = _tmp4 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/types.go b/rlp/rlpgen/types.go new file mode 100644 index 0000000000..ea7dc96d88 --- /dev/null +++ b/rlp/rlpgen/types.go @@ -0,0 +1,124 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "fmt" + "go/types" + "reflect" +) + +// typeReflectKind gives the reflect.Kind that represents typ. +func typeReflectKind(typ types.Type) reflect.Kind { + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + if k >= types.Bool && k <= types.Complex128 { + // value order matches for Bool..Complex128 + return reflect.Bool + reflect.Kind(k-types.Bool) + } + if k == types.String { + return reflect.String + } + if k == types.UnsafePointer { + return reflect.UnsafePointer + } + panic(fmt.Errorf("unhandled BasicKind %v", k)) + case *types.Array: + return reflect.Array + case *types.Chan: + return reflect.Chan + case *types.Interface: + return reflect.Interface + case *types.Map: + return reflect.Map + case *types.Pointer: + return reflect.Ptr + case *types.Signature: + return reflect.Func + case *types.Slice: + return reflect.Slice + case *types.Struct: + return reflect.Struct + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'. +func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string { + // Resolve type name. + typ := resolveUnderlying(vtyp) + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + switch { + case k == types.Bool: + return v + case k >= types.Uint && k <= types.Complex128: + return fmt.Sprintf("%s != 0", v) + case k == types.String: + return fmt.Sprintf(`%s != ""`, v) + default: + panic(fmt.Errorf("unhandled BasicKind %v", k)) + } + case *types.Array, *types.Struct: + return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify)) + case *types.Interface, *types.Pointer, *types.Signature: + return fmt.Sprintf("%s != nil", v) + case *types.Slice, *types.Map: + return fmt.Sprintf("len(%s) > 0", v) + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// isBigInt checks whether 'typ' is "math/big".Int. +func isBigInt(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "math/big" && name.Name() == "Int" +} + +// isUint256 checks whether 'typ' is "github.com/holiman/uint256".Int. +func isUint256(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "github.com/holiman/uint256" && name.Name() == "Int" +} + +// isByte checks whether the underlying type of 'typ' is uint8. +func isByte(typ types.Type) bool { + basic, ok := resolveUnderlying(typ).(*types.Basic) + return ok && basic.Kind() == types.Uint8 +} + +func resolveUnderlying(typ types.Type) types.Type { + for { + t := typ.Underlying() + if t == typ { + return t + } + typ = t + } +} diff --git a/rlp/safe.go b/rlp/safe.go new file mode 100644 index 0000000000..3c910337b6 --- /dev/null +++ b/rlp/safe.go @@ -0,0 +1,27 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build nacl || js || !cgo +// +build nacl js !cgo + +package rlp + +import "reflect" + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + return v.Slice(0, length).Bytes() +} diff --git a/rlp/typecache.go b/rlp/typecache.go index 3df799e1ec..c3244050bf 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -19,138 +19,222 @@ package rlp import ( "fmt" "reflect" - "strings" "sync" -) + "sync/atomic" -var ( - typeCacheMutex sync.RWMutex - typeCache = make(map[typekey]*typeinfo) + "github.com/tomochain/tomochain/rlp/internal/rlpstruct" ) +// typeinfo is an entry in the type cache. type typeinfo struct { - decoder - writer -} - -// represents struct tags -type tags struct { - // rlp:"nil" controls whether empty input results in a nil pointer. - nilOK bool - // rlp:"tail" controls whether this field swallows additional list - // elements. It can only be set for the last field, which must be - // of slice type. - tail bool - // rlp:"-" ignores fields. - ignored bool + decoder decoder + decoderErr error // error from makeDecoder + writer writer + writerErr error // error from makeWriter } +// typekey is the key of a type in typeCache. It includes the struct tags because +// they might generate a different decoder. type typekey struct { reflect.Type - // the key must include the struct tags because they - // might generate a different decoder. - tags + rlpstruct.Tags } type decoder func(*Stream, reflect.Value) error -type writer func(reflect.Value, *encbuf) error +type writer func(reflect.Value, *encBuffer) error + +var theTC = newTypeCache() + +type typeCache struct { + cur atomic.Value + + // This lock synchronizes writers. + mu sync.Mutex + next map[typekey]*typeinfo +} + +func newTypeCache() *typeCache { + c := new(typeCache) + c.cur.Store(make(map[typekey]*typeinfo)) + return c +} + +func cachedDecoder(typ reflect.Type) (decoder, error) { + info := theTC.info(typ) + return info.decoder, info.decoderErr +} + +func cachedWriter(typ reflect.Type) (writer, error) { + info := theTC.info(typ) + return info.writer, info.writerErr +} + +func (c *typeCache) info(typ reflect.Type) *typeinfo { + key := typekey{Type: typ} + if info := c.cur.Load().(map[typekey]*typeinfo)[key]; info != nil { + return info + } + + // Not in the cache, need to generate info for this type. + return c.generate(typ, rlpstruct.Tags{}) +} + +func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { + c.mu.Lock() + defer c.mu.Unlock() + + cur := c.cur.Load().(map[typekey]*typeinfo) + if info := cur[typekey{typ, tags}]; info != nil { + return info + } -func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { - typeCacheMutex.RLock() - info := typeCache[typekey{typ, tags}] - typeCacheMutex.RUnlock() - if info != nil { - return info, nil + // Copy cur to next. + c.next = make(map[typekey]*typeinfo, len(cur)+1) + for k, v := range cur { + c.next[k] = v } - // not in the cache, need to generate info for this type. - typeCacheMutex.Lock() - defer typeCacheMutex.Unlock() - return cachedTypeInfo1(typ, tags) + + // Generate. + info := c.infoWhileGenerating(typ, tags) + + // next -> cur + c.cur.Store(c.next) + c.next = nil + return info } -func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { +func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { key := typekey{typ, tags} - info := typeCache[key] - if info != nil { - // another goroutine got the write lock first - return info, nil + if info := c.next[key]; info != nil { + return info } - // put a dummmy value into the cache before generating. - // if the generator tries to lookup itself, it will get + // Put a dummy value into the cache before generating. + // If the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[key] = new(typeinfo) - info, err := genTypeInfo(typ, tags) - if err != nil { - // remove the dummy value if the generator fails - delete(typeCache, key) - return nil, err - } - *typeCache[key] = *info - return typeCache[key], err + info := new(typeinfo) + c.next[key] = info + info.generate(typ, tags) + return info } type field struct { - index int - info *typeinfo + index int + info *typeinfo + optional bool } +// structFields resolves the typeinfo of all public fields in a struct type. func structFields(typ reflect.Type) (fields []field, err error) { + // Convert fields to rlpstruct.Field. + var allStructFields []rlpstruct.Field for i := 0; i < typ.NumField(); i++ { - if f := typ.Field(i); f.PkgPath == "" { // exported - tags, err := parseStructTag(typ, i) - if err != nil { - return nil, err - } - if tags.ignored { - continue - } - info, err := cachedTypeInfo1(f.Type, tags) - if err != nil { - return nil, err - } - fields = append(fields, field{i, info}) + rf := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: rf.Name, + Index: i, + Exported: rf.PkgPath == "", + Tag: string(rf.Tag), + Type: *rtypeToStructType(rf.Type, nil), + }) + } + + // Filter/validate fields. + structFields, structTags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + if tagErr, ok := err.(rlpstruct.TagError); ok { + tagErr.StructType = typ.String() + return nil, tagErr } + return nil, err + } + + // Resolve typeinfo. + for i, sf := range structFields { + typ := typ.Field(sf.Index).Type + tags := structTags[i] + info := theTC.infoWhileGenerating(typ, tags) + fields = append(fields, field{sf.Index, info, tags.Optional}) } return fields, nil } -func parseStructTag(typ reflect.Type, fi int) (tags, error) { - f := typ.Field(fi) - var ts tags - for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { - switch t = strings.TrimSpace(t); t { - case "": - case "-": - ts.ignored = true - case "nil": - ts.nilOK = true - case "tail": - ts.tail = true - if fi != typ.NumField()-1 { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) - } - if f.Type.Kind() != reflect.Slice { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name) - } - default: - return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name) +// firstOptionalField returns the index of the first field with "optional" tag. +func firstOptionalField(fields []field) int { + for i, f := range fields { + if f.optional { + return i } } - return ts, nil + return len(fields) } -func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { - info = new(typeinfo) - if info.decoder, err = makeDecoder(typ, tags); err != nil { - return nil, err +type structFieldError struct { + typ reflect.Type + field int + err error +} + +func (e structFieldError) Error() string { + return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name) +} + +func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) { + i.decoder, i.decoderErr = makeDecoder(typ, tags) + i.writer, i.writerErr = makeWriter(typ, tags) +} + +// rtypeToStructType converts typ to rlpstruct.Type. +func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type { + k := typ.Kind() + if k == reflect.Invalid { + panic("invalid kind") } - if info.writer, err = makeWriter(typ, tags); err != nil { - return nil, err + + if prev := rec[typ]; prev != nil { + return prev // short-circuit for recursive types + } + if rec == nil { + rec = make(map[reflect.Type]*rlpstruct.Type) + } + + t := &rlpstruct.Type{ + Name: typ.String(), + Kind: k, + IsEncoder: typ.Implements(encoderInterface), + IsDecoder: typ.Implements(decoderInterface), + } + rec[typ] = t + if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr { + t.Elem = rtypeToStructType(typ.Elem(), rec) + } + return t +} + +// typeNilKind gives the RLP value kind for nil pointers to 'typ'. +func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind { + styp := rtypeToStructType(typ, nil) + + var nk rlpstruct.NilKind + if tags.NilOK { + nk = tags.NilKind + } else { + nk = styp.DefaultNilValue() + } + switch nk { + case rlpstruct.NilKindString: + return String + case rlpstruct.NilKindList: + return List + default: + panic("invalid nil kind value") } - return info, nil } func isUint(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uintptr } + +func isByte(typ reflect.Type) bool { + return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) +} diff --git a/rlp/unsafe.go b/rlp/unsafe.go new file mode 100644 index 0000000000..2152ba35fc --- /dev/null +++ b/rlp/unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build !nacl && !js && cgo +// +build !nacl,!js,cgo + +package rlp + +import ( + "reflect" + "unsafe" +) + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + var s []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) + hdr.Data = v.UnsafeAddr() + hdr.Cap = length + hdr.Len = length + return s +} diff --git a/rpc/json.go b/rpc/json.go index 715f33ee16..e35a74118a 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -96,6 +96,10 @@ func (err *jsonError) ErrorCode() int { return err.Code } +func (err *jsonError) ErrorData() interface{} { + return err.Data +} + // NewCodec creates a new RPC server codec with support for JSON-RPC 2.0 based // on explicitly given encoding and decoding methods. func NewCodec(rwc io.ReadWriteCloser, encode, decode func(v interface{}) error) ServerCodec { diff --git a/tests/state_test.go b/tests/state_test.go index 7c8c5e9268..a6d23edacf 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -26,6 +26,9 @@ import ( ) func TestState(t *testing.T) { + if testing.Short() { + t.Skip("skipping testing in short mode") + } t.Parallel() st := new(testMatcher) @@ -50,13 +53,17 @@ func TestState(t *testing.T) { subtest := subtest key := fmt.Sprintf("%s/%d", subtest.Fork, subtest.Index) name := name + "/" + key - t.Run(key, func(t *testing.T) { - if subtest.Fork == "Constantinople" { - t.Skip("constantinople not supported yet") - } + + t.Run(key+"/trie", func(t *testing.T) { + withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { + _, err := test.Run(subtest, vmconfig, false) + return st.checkFailure(t, name+"/trie", err) + }) + }) + t.Run(key+"/snap", func(t *testing.T) { withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { - _, err := test.Run(subtest, vmconfig) - return st.checkFailure(t, name, err) + _, err := test.Run(subtest, vmconfig, true) + return st.checkFailure(t, name+"/snap", err) }) }) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index e532aa8a46..8e99c9b760 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -19,8 +19,8 @@ package tests import ( "encoding/hex" "encoding/json" + "errors" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" "strings" @@ -28,8 +28,9 @@ import ( "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" "github.com/tomochain/tomochain/core" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" - "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/core/state/snapshot" "github.com/tomochain/tomochain/core/vm" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/sha3" @@ -121,14 +122,14 @@ func (t *StateTest) Subtests() []StateSubtest { } // Run executes a specific subtest. -func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateDB, error) { +func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config, snapshotter bool) (*state.StateDB, error) { config, ok := Forks[subtest.Fork] if !ok { return nil, UnsupportedForkError{subtest.Fork} } block := t.genesis(config).ToBlock(nil) db := rawdb.NewMemoryDatabase() - statedb := MakePreState(db, t.json.Pre) + statedb := MakePreState(db, t.json.Pre, snapshotter) post := t.json.Post[subtest.Fork][subtest.Index] msg, err := t.json.Tx.toMessage(post) @@ -144,7 +145,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD snapshot := statedb.Snapshot() coinbase := &t.json.Env.Coinbase - if _, _, _, err := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil { + if _, err := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil { statedb.RevertToSnapshot(snapshot) } if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { @@ -161,9 +162,9 @@ func (t *StateTest) gasLimit(subtest StateSubtest) uint64 { return t.json.Tx.GasLimit[t.json.Post[subtest.Fork][subtest.Index].Indexes.Gas] } -func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB { +func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter bool) *state.StateDB { sdb := state.NewDatabase(db) - statedb, _ := state.New(common.Hash{}, sdb) + statedb, _ := state.New(common.Hash{}, sdb, nil) for addr, a := range accounts { statedb.SetCode(addr, a.Code) statedb.SetNonce(addr, a.Nonce) @@ -174,7 +175,12 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB } // Commit and re-open to start with a clean state. root, _ := statedb.Commit(false) - statedb, _ = state.New(root, sdb) + + var snaps *snapshot.Tree + if snapshotter { + snaps = snapshot.New(db, sdb.TrieDB(), 1, root, false) + } + statedb, _ = state.New(root, sdb, snaps) return statedb } @@ -190,7 +196,7 @@ func (t *StateTest) genesis(config *params.ChainConfig) *core.Genesis { } } -func (tx *stTransaction) toMessage(ps stPostState) (core.Message, error) { +func (tx *stTransaction) toMessage(ps stPostState) (*core.Message, error) { // Derive sender from private key if present. var from common.Address if len(tx.PrivateKey) > 0 { @@ -235,7 +241,21 @@ func (tx *stTransaction) toMessage(ps stPostState) (core.Message, error) { if err != nil { return nil, fmt.Errorf("invalid tx data %q", dataHex) } - msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, true, nil) + // If baseFee provided, set gasPrice to effectiveGasPrice. + gasPrice := tx.GasPrice + if gasPrice == nil { + return nil, errors.New("no gas price provided") + } + + msg := &core.Message{ + From: from, + To: to, + Nonce: tx.Nonce, + Value: value, + GasLimit: gasLimit, + GasPrice: gasPrice, + Data: data, + } return msg, nil } diff --git a/tests/vm_test.go b/tests/vm_test.go index 9e1f735436..234d73620c 100644 --- a/tests/vm_test.go +++ b/tests/vm_test.go @@ -17,15 +17,15 @@ package tests import ( - "github.com/tomochain/tomochain/common" "math/big" "testing" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/vm" ) func TestVM(t *testing.T) { - common.TIPTomoXCancellationFee=big.NewInt(100000000) + common.TIPTomoXCancellationFee = big.NewInt(100000000) t.Parallel() vmt := new(testMatcher) vmt.fails("^vmSystemOperationsTest.json/createNameRegistrator$", "fails without parallel execution") @@ -37,7 +37,10 @@ func TestVM(t *testing.T) { vmt.walk(t, vmTestDir, func(t *testing.T, name string, test *VMTest) { withTrace(t, test.json.Exec.GasLimit, func(vmconfig vm.Config) error { - return vmt.checkFailure(t, name, test.Run(vmconfig)) + return vmt.checkFailure(t, name+"/trie", test.Run(vmconfig, false)) + }) + withTrace(t, test.json.Exec.GasLimit, func(vmconfig vm.Config) error { + return vmt.checkFailure(t, name+"/snap", test.Run(vmconfig, true)) }) }) } diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 01c471af27..c2a56d7796 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -20,9 +20,10 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/tomochain/tomochain/core/rawdb" "math/big" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/common/hexutil" "github.com/tomochain/tomochain/common/math" @@ -78,9 +79,9 @@ type vmExecMarshaling struct { GasPrice *math.HexOrDecimal256 } -func (t *VMTest) Run(vmconfig vm.Config) error { +func (t *VMTest) Run(vmconfig vm.Config, snapshotter bool) error { db := rawdb.NewMemoryDatabase() - statedb := MakePreState(db, t.json.Pre) + statedb := MakePreState(db, t.json.Pre, snapshotter) ret, gasRemaining, err := t.exec(statedb, vmconfig) if t.json.GasRemaining == nil { diff --git a/tomox/tomox.go b/tomox/tomox.go index ae5b960e68..3f19e02f3e 100644 --- a/tomox/tomox.go +++ b/tomox/tomox.go @@ -573,6 +573,11 @@ func (tomox *TomoX) GetTradingState(block *types.Block, author common.Address) ( return tradingstate.New(root, tomox.StateCache) } + +func (tomox *TomoX) GetEmptyTradingState() (*tradingstate.TradingStateDB, error) { + return tradingstate.New(tradingstate.EmptyRoot, tomox.StateCache) +} + func (tomox *TomoX) GetStateCache() tradingstate.Database { return tomox.StateCache } diff --git a/tomox/tradingstate/database.go b/tomox/tradingstate/database.go index 56acf61ec6..e77b6be1a6 100644 --- a/tomox/tradingstate/database.go +++ b/tomox/tradingstate/database.go @@ -81,7 +81,7 @@ type Trie interface { func NewDatabase(db ethdb.Database) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabase(db), + db: trie.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), codeSizeCache: csc, } } diff --git a/tomox/tradingstate/tomox_trie.go b/tomox/tradingstate/tomox_trie.go index 908648def9..197e50b4c0 100644 --- a/tomox/tradingstate/tomox_trie.go +++ b/tomox/tradingstate/tomox_trie.go @@ -18,11 +18,11 @@ package tradingstate import ( "fmt" - "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/trie" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/trie" ) // TomoXTrie wraps a trie with key hashing. In a secure trie, all @@ -78,10 +78,10 @@ func (t *TomoXTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryGet(key []byte) ([]byte, error) { - return t.trie.TryGet(key) + return t.trie.Get(key) } -// TryGetBestLeftKey returns the value of max left leaf +// TryGetBestLeftKeyAndValue returns the value of max left leaf // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryGetBestLeftKeyAndValue() ([]byte, []byte, error) { return t.trie.TryGetBestLeftKeyAndValue() @@ -91,7 +91,7 @@ func (t *TomoXTrie) TryGetAllLeftKeyAndValue(limit []byte) ([][]byte, [][]byte, return t.trie.TryGetAllLeftKeyAndValue(limit) } -// TryGetBestRightKey returns the value of max left leaf +// TryGetBestRightKeyAndValue returns the value of max left leaf // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryGetBestRightKeyAndValue() ([]byte, []byte, error) { return t.trie.TryGetBestRightKeyAndValue() @@ -118,7 +118,7 @@ func (t *TomoXTrie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryUpdate(key, value []byte) error { - err := t.trie.TryUpdate(key, value) + err := t.trie.Update(key, value) if err != nil { return err } @@ -137,7 +137,7 @@ func (t *TomoXTrie) Delete(key []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryDelete(key []byte) error { delete(t.getSecKeyCache(), string(key)) - return t.trie.TryDelete(key) + return t.trie.Delete(key) } // GetKey returns the sha3 preimage of a hashed key that was diff --git a/tomoxlending/lendingstate/database.go b/tomoxlending/lendingstate/database.go index d823602599..c27c41dcfe 100644 --- a/tomoxlending/lendingstate/database.go +++ b/tomoxlending/lendingstate/database.go @@ -80,7 +80,7 @@ type Trie interface { func NewDatabase(db ethdb.Database) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabase(db), + db: trie.NewDatabaseWithConfig(db, &trie.Config{Preimages: true}), codeSizeCache: csc, } } diff --git a/tomoxlending/lendingstate/lendingitem_test.go b/tomoxlending/lendingstate/lendingitem_test.go index b83c59ebee..564dffddf6 100644 --- a/tomoxlending/lendingstate/lendingitem_test.go +++ b/tomoxlending/lendingstate/lendingitem_test.go @@ -2,17 +2,18 @@ package lendingstate import ( "fmt" + "math/big" + "math/rand" + "os" + "testing" + "time" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/core/state" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/crypto/sha3" "github.com/tomochain/tomochain/rpc" - "math/big" - "math/rand" - "os" - "testing" - "time" ) func TestLendingItem_VerifyLendingSide(t *testing.T) { @@ -152,7 +153,7 @@ func SetCollateralDetail(statedb *state.StateDB, token common.Address, depositRa func TestVerifyBalance(t *testing.T) { db := rawdb.NewMemoryDatabase() - statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) relayer := common.HexToAddress("0x0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e") uAddr := common.HexToAddress("0xDeE6238780f98c0ca2c2C28453149bEA49a3Abc9") lendingToken := common.HexToAddress("0xd9bb01454c85247B2ef35BB5BE57384cC275a8cf") // USD diff --git a/tomoxlending/lendingstate/tomox_trie.go b/tomoxlending/lendingstate/tomox_trie.go index 8ff0a5633a..2852139ae0 100644 --- a/tomoxlending/lendingstate/tomox_trie.go +++ b/tomoxlending/lendingstate/tomox_trie.go @@ -18,11 +18,11 @@ package lendingstate import ( "fmt" - "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/trie" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/trie" ) // TomoXTrie wraps a trie with key hashing. In a secure trie, all @@ -78,16 +78,16 @@ func (t *TomoXTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryGet(key []byte) ([]byte, error) { - return t.trie.TryGet(key) + return t.trie.Get(key) } -// TryGetBestLeftKey returns the value of max left leaf +// TryGetBestLeftKeyAndValue returns the value of max left leaf // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryGetBestLeftKeyAndValue() ([]byte, []byte, error) { return t.trie.TryGetBestLeftKeyAndValue() } -// TryGetBestRightKey returns the value of max left leaf +// TryGetBestRightKeyAndValue returns the value of max left leaf // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryGetBestRightKeyAndValue() ([]byte, []byte, error) { return t.trie.TryGetBestRightKeyAndValue() @@ -114,7 +114,7 @@ func (t *TomoXTrie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryUpdate(key, value []byte) error { - err := t.trie.TryUpdate(key, value) + err := t.trie.Update(key, value) if err != nil { return err } @@ -133,7 +133,7 @@ func (t *TomoXTrie) Delete(key []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *TomoXTrie) TryDelete(key []byte) error { delete(t.getSecKeyCache(), string(key)) - return t.trie.TryDelete(key) + return t.trie.Delete(key) } // GetKey returns the sha3 preimage of a hashed key that was diff --git a/trie/committer.go b/trie/committer.go index 78ed86bb4a..43a31381b9 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -22,8 +22,8 @@ import ( "sync" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" - "golang.org/x/crypto/sha3" ) // leafChanSize is the size of the leafCh. It's a pretty arbitrary number, to allow @@ -46,7 +46,7 @@ type leaf struct { // processed sequentially - onleaf will never be called in parallel or out of order. type committer struct { tmp sliceBuffer - sha keccakState + sha crypto.KeccakState onleaf LeafCallback leafCh chan *leaf @@ -57,7 +57,7 @@ var committerPool = sync.Pool{ New: func() interface{} { return &committer{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: crypto.NewKeccakState(), } }, } diff --git a/trie/database.go b/trie/database.go index bb2da07c29..446fda34c2 100644 --- a/trie/database.go +++ b/trie/database.go @@ -25,6 +25,7 @@ import ( "time" "github.com/VictoriaMetrics/fastcache" + "github.com/tomochain/tomochain/common" "github.com/tomochain/tomochain/ethdb" "github.com/tomochain/tomochain/log" @@ -65,6 +66,12 @@ const secureKeyPrefixLength = 11 // secureKeyLength is the length of the above prefix + 32byte hash. const secureKeyLength = secureKeyPrefixLength + 32 +// Config defines all necessary options for database. +type Config struct { + Cache int // Memory allowance (MB) to use for caching trie nodes in memory + Preimages bool // Flag whether the preimage of trie key is recorded +} + // Database is an intermediate write layer between the trie data structures and // the disk database. The aim is to accumulate trie writes in-memory and only // periodically flush a couple tries to disk, garbage collecting the remainder. @@ -74,6 +81,7 @@ const secureKeyLength = secureKeyPrefixLength + 32 // behind this split design is to provide read access to RPC handlers and sync // servers even while the trie is executing expensive garbage collection. type Database struct { + config *Config // Configuration for trie database diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes cleans *fastcache.Cache // GC friendly memory Cache of clean Node RLPs @@ -81,7 +89,7 @@ type Database struct { oldest common.Hash // Oldest tracked Node, flush-list head newest common.Hash // Newest tracked Node, flush-list tail - preimages map[common.Hash][]byte // Preimages of nodes from the secure trie + preimages *preimageStore // The store for caching preimages gctime time.Duration // Time spent on garbage collection since last commit gcnodes uint64 // Nodes garbage collected since last commit @@ -106,7 +114,12 @@ type rawNode []byte func (n rawNode) Cache() (HashNode, bool) { panic("this should never end up in a live trie") } func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } -// rawFullNode represents only the useful data content of a full Node, with the +func (n rawNode) EncodeRLP(w io.Writer) error { + _, err := w.Write([]byte(n)) + return err +} + +// rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. type rawFullNode [17]Node @@ -184,7 +197,7 @@ func (n *cachedNode) obj(hash common.Hash) Node { // forChilds invokes the callback for all the tracked children of this Node, // both the implicit ones from inside the Node as well as the explicit ones -//from outside the Node. +// from outside the Node. func (n *cachedNode) forChilds(onChild func(hash common.Hash)) { for child := range n.children { onChild(child) @@ -277,26 +290,32 @@ func expandNode(hash HashNode, n Node) Node { // NewDatabase creates a new trie database to store ephemeral trie content before // its written out to disk or garbage collected. No read Cache is created, so all // data retrievals will hit the underlying disk database. -func NewDatabase(diskdb ethdb.KeyValueStore) *Database { - return NewDatabaseWithCache(diskdb, 0) +func NewDatabase(diskdb ethdb.Database) *Database { + return NewDatabaseWithConfig(diskdb, nil) } -// NewDatabaseWithCache creates a new trie database to store ephemeral trie content +// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // before its written out to disk or garbage collected. It also acts as a read Cache // for nodes loaded from disk. -func NewDatabaseWithCache(diskdb ethdb.KeyValueStore, cache int) *Database { +func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database { var cleans *fastcache.Cache - if cache > 0 { - cleans = fastcache.New(cache * 1024 * 1024) + if config != nil && config.Cache > 0 { + cleans = fastcache.New(config.Cache * 1024 * 1024) + } + var preimages *preimageStore + if config != nil && config.Preimages { + preimages = newPreimageStore(diskdb) } - return &Database{ + db := &Database{ diskdb: diskdb, cleans: cleans, dirties: map[common.Hash]*cachedNode{{}: { children: make(map[common.Hash]uint16), }}, - preimages: make(map[common.Hash][]byte), + preimages: preimages, } + + return db } // DiskDB retrieves the persistent storage backing the trie database. @@ -352,11 +371,12 @@ func (db *Database) insert(hash common.Hash, size int, node Node) { // yet unknown. The method will make a copy of the slice. // // Note, this method assumes that the database's Lock is held! +// This function's still be kept because of TomoX tries func (db *Database) InsertPreimage(hash common.Hash, preimage []byte) { - if _, ok := db.preimages[hash]; ok { + if _, ok := db.preimages.preimages[hash]; ok { return } - db.preimages[hash] = common.CopyBytes(preimage) + db.preimages.preimages[hash] = common.CopyBytes(preimage) db.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) } @@ -440,7 +460,7 @@ func (db *Database) Node(hash common.Hash) ([]byte, error) { func (db *Database) Preimage(hash common.Hash) ([]byte, error) { // Retrieve the Node from Cache if available db.Lock.RLock() - preimage := db.preimages[hash] + preimage := db.preimages.preimages[hash] db.Lock.RUnlock() if preimage != nil { @@ -607,7 +627,7 @@ func (db *Database) Cap(limit common.StorageSize) error { // leave for later to deduplicate writes. flushPreimages := db.preimagesSize > 4*1024*1024 if flushPreimages { - for hash, preimage := range db.preimages { + for hash, preimage := range db.preimages.preimages { copy(keyBuf[secureKeyPrefixLength:], hash[:]) if err := batch.Put(keyBuf[:], preimage); err != nil { log.Error("Failed to commit Preimage from trie database", "err", err) @@ -656,7 +676,7 @@ func (db *Database) Cap(limit common.StorageSize) error { defer db.Lock.Unlock() if flushPreimages { - db.preimages = make(map[common.Hash][]byte) + db.preimages.preimages = make(map[common.Hash][]byte) db.preimagesSize = 0 } for db.oldest != oldest { @@ -706,26 +726,28 @@ func (db *Database) Commit(node common.Hash, report bool) error { copy(keyBuf[:], secureKeyPrefix) // Move all of the accumulated preimages into a write batch - for hash, preimage := range db.preimages { - copy(keyBuf[secureKeyPrefixLength:], hash[:]) - if err := batch.Put(keyBuf[:], preimage); err != nil { - log.Error("Failed to commit Preimage from trie database", "err", err) - return err - } - // If the batch is too large, flush to disk - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { + if db.preimages != nil { + for hash, preimage := range db.preimages.preimages { + copy(keyBuf[secureKeyPrefixLength:], hash[:]) + if err := batch.Put(keyBuf[:], preimage); err != nil { + log.Error("Failed to commit Preimage from trie database", "err", err) return err } - batch.Reset() + // If the batch is too large, flush to disk + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return err + } + batch.Reset() + } } + // Since we're going to replay trie Node writes into the clean Cache, flush out + // any batched pre-images before continuing. + if err := batch.Write(); err != nil { + return err + } + batch.Reset() } - // Since we're going to replay trie Node writes into the clean Cache, flush out - // any batched pre-images before continuing. - if err := batch.Write(); err != nil { - return err - } - batch.Reset() // Move the trie itself into the batch, flushing if enough data is accumulated nodes, storage := len(db.dirties), db.dirtiesSize @@ -747,10 +769,6 @@ func (db *Database) Commit(node common.Hash, report bool) error { batch.Replay(uncacher) batch.Reset() - // Reset the storage counters and bumpd metrics - db.preimages = make(map[common.Hash][]byte) - db.preimagesSize = 0 - memcacheCommitTimeTimer.Update(time.Since(start)) memcacheCommitSizeMeter.Mark(int64(storage - db.dirtiesSize)) memcacheCommitNodesMeter.Mark(int64(nodes - len(db.dirties))) @@ -785,6 +803,7 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane if err != nil { return err } + if err := batch.Put(hash[:], node.rlp()); err != nil { return err } @@ -794,9 +813,12 @@ func (db *Database) commit(hash common.Hash, batch ethdb.Batch, uncacher *cleane return err } db.Lock.Lock() - batch.Replay(uncacher) + err := batch.Replay(uncacher) batch.Reset() db.Lock.Unlock() + if err != nil { + return err + } } return nil } @@ -810,7 +832,7 @@ type cleaner struct { // Put reacts to database writes and implements dirty data uncaching. This is the // post-processing step of a commit operation where the already persisted trie is // removed from the dirty Cache and moved into the clean Cache. The reason behind -// the two-phase commit is to ensure ensure data availability while moving from +// the two-phase commit is to ensure data availability while moving from // memory to disk. func (c *cleaner) Put(key []byte, rlp []byte) error { hash := common.BytesToHash(key) diff --git a/trie/database_test.go b/trie/database_test.go index ed6b58fdc5..126923b12c 100644 --- a/trie/database_test.go +++ b/trie/database_test.go @@ -20,13 +20,13 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/core/rawdb" ) // Tests that the trie database returns a missing trie Node error if attempting // to retrieve the meta root. func TestDatabaseMetarootFetch(t *testing.T) { - db := NewDatabase(memorydb.New()) + db := NewDatabase(rawdb.NewMemoryDatabase()) if _, err := db.Node(common.Hash{}); err == nil { t.Fatalf("metaroot retrieval succeeded") } diff --git a/trie/hasher.go b/trie/hasher.go index 8a2ea18068..d4a36dd5ed 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -1,4 +1,4 @@ -// Copyright 2019 The go-ethereum Authors +// Copyright 2016 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -17,21 +17,12 @@ package trie import ( - "hash" "sync" + "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/rlp" - "golang.org/x/crypto/sha3" ) -// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - type sliceBuffer []byte func (b *sliceBuffer) Write(data []byte) (n int, err error) { @@ -46,17 +37,19 @@ func (b *sliceBuffer) Reset() { // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { - sha keccakState - tmp sliceBuffer - parallel bool // Whether to use paralallel threads when hashing + sha crypto.KeccakState + tmp []byte + encbuf rlp.EncoderBuffer + parallel bool // Whether to use parallel threads when hashing } // hasherPool holds pureHashers var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ - tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + tmp: make([]byte, 0, 550), // cap is as large as a full fullNode. + sha: crypto.NewKeccakState(), + encbuf: rlp.NewEncoderBuffer(nil), } }, } @@ -71,14 +64,14 @@ func returnHasherToPool(h *hasher) { hasherPool.Put(h) } -// hash collapses a Node down into a hash Node, also returning a copy of the -// original Node initialized with the computed hash to replace the original one. +// hash collapses a node down into a hash node, also returning a copy of the +// original node initialized with the computed hash to replace the original one. func (h *hasher) hash(n Node, force bool) (hashed Node, cached Node) { - // We're not storing the Node, just hashing, use available cached data + // Return the cached hash if it's available if hash, _ := n.Cache(); hash != nil { return hash, n } - // Trie not processed yet or needs storage, walk the children + // Trie not processed yet, walk the children switch n := n.(type) { case *ShortNode: collapsed, cached := h.hashShortNodeChildren(n) @@ -106,11 +99,11 @@ func (h *hasher) hash(n Node, force bool) (hashed Node, cached Node) { } } -// hashShortNodeChildren collapses the short Node. The returned collapsed Node +// hashShortNodeChildren collapses the short node. The returned collapsed node // holds a live reference to the Key, and must not be modified. // The cached func (h *hasher) hashShortNodeChildren(n *ShortNode) (collapsed, cached *ShortNode) { - // Hash the short Node's child, caching the newly hashed subtree + // Hash the short node's child, caching the newly hashed subtree collapsed, cached = n.copy(), n.copy() // Previously, we did copy this one. We don't seem to need to actually // do that, since we don't overwrite/reuse keys @@ -125,7 +118,7 @@ func (h *hasher) hashShortNodeChildren(n *ShortNode) (collapsed, cached *ShortNo } func (h *hasher) hashFullNodeChildren(n *FullNode) (collapsed *FullNode, cached *FullNode) { - // Hash the full Node's children, caching the newly hashed subtrees + // Hash the full node's children, caching the newly hashed subtrees cached = n.copy() collapsed = n.copy() if h.parallel { @@ -156,35 +149,46 @@ func (h *hasher) hashFullNodeChildren(n *FullNode) (collapsed *FullNode, cached return collapsed, cached } -// shortnodeToHash creates a HashNode from a ShortNode. The supplied shortnode +// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode // should have hex-type Key, which will be converted (without modification) // into compact form for RLP encoding. // If the rlp data is smaller than 32 bytes, `nil` is returned. func (h *hasher) shortnodeToHash(n *ShortNode, force bool) Node { - h.tmp.Reset() - if err := rlp.Encode(&h.tmp, n); err != nil { - panic("encode error: " + err.Error()) - } + n.encode(h.encbuf) + enc := h.encodedBytes() - if len(h.tmp) < 32 && !force { + if len(enc) < 32 && !force { return n // Nodes smaller than 32 bytes are stored inside their parent } - return h.hashData(h.tmp) + return h.hashData(enc) } -// shortnodeToHash is used to creates a HashNode from a set of hashNodes, (which +// shortnodeToHash is used to creates a hashNode from a set of hashNodes, (which // may contain nil values) func (h *hasher) fullnodeToHash(n *FullNode, force bool) Node { - h.tmp.Reset() - // Generate the RLP encoding of the Node - if err := n.EncodeRLP(&h.tmp); err != nil { - panic("encode error: " + err.Error()) - } + n.encode(h.encbuf) + enc := h.encodedBytes() - if len(h.tmp) < 32 && !force { + if len(enc) < 32 && !force { return n // Nodes smaller than 32 bytes are stored inside their parent } - return h.hashData(h.tmp) + return h.hashData(enc) +} + +// encodedBytes returns the result of the last encoding operation on h.encbuf. +// This also resets the encoder buffer. +// +// All node encoding must be done like this: +// +// node.encode(h.encbuf) +// enc := h.encodedBytes() +// +// This convention exists because node.encode can only be inlined/escape-analyzed when +// called on a concrete receiver type. +func (h *hasher) encodedBytes() []byte { + h.tmp = h.encbuf.AppendToBytes(h.tmp[:0]) + h.encbuf.Reset(nil) + return h.tmp } // hashData hashes the provided data @@ -197,8 +201,8 @@ func (h *hasher) hashData(data []byte) HashNode { } // proofHash is used to construct trie proofs, and returns the 'collapsed' -// Node (for later RLP encoding) aswell as the hashed Node -- unless the -// Node is smaller than 32 bytes, in which case it will be returned as is. +// node (for later RLP encoding) as well as the hashed node -- unless the +// node is smaller than 32 bytes, in which case it will be returned as is. // This method does not do anything on value- or hash-nodes. func (h *hasher) proofHash(original Node) (collapsed, hashed Node) { switch n := original.(type) { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 26d48c95cd..b93d664220 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -23,7 +23,7 @@ import ( "testing" "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/ethdb/memorydb" + "github.com/tomochain/tomochain/core/rawdb" ) func TestIterator(t *testing.T) { @@ -292,7 +292,7 @@ func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueA func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } func testIteratorContinueAfterError(t *testing.T, memonly bool) { - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) tr, _ := New(common.Hash{}, triedb) @@ -383,7 +383,7 @@ func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { // Commit test trie to Db, then remove the Node containing "bars". - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) ctr, _ := New(common.Hash{}, triedb) diff --git a/trie/node.go b/trie/node.go index ffb2f18116..fbbe293413 100644 --- a/trie/node.go +++ b/trie/node.go @@ -30,6 +30,7 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b type Node interface { fstring(string) string Cache() (HashNode, bool) + encode(w rlp.EncoderBuffer) } type ( @@ -52,16 +53,9 @@ var nilValueNode = ValueNode(nil) // EncodeRLP encodes a full Node into the consensus RLP format. func (n *FullNode) EncodeRLP(w io.Writer) error { - var nodes [17]Node - - for i, child := range &n.Children { - if child != nil { - nodes[i] = child - } else { - nodes[i] = nilValueNode - } - } - return rlp.Encode(w, nodes) + eb := rlp.NewEncoderBuffer(w) + n.encode(eb) + return eb.Flush() } func (n *FullNode) copy() *FullNode { copy := *n; return © } diff --git a/trie/node_enc.go b/trie/node_enc.go new file mode 100644 index 0000000000..b987abfbf5 --- /dev/null +++ b/trie/node_enc.go @@ -0,0 +1,72 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "github.com/tomochain/tomochain/rlp" +) + +func nodeToBytes(n Node) []byte { + w := rlp.NewEncoderBuffer(nil) + n.encode(w) + result := w.ToBytes() + w.Flush() + return result +} + +func (n *FullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n.Children { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *ShortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n HashNode) encode(w rlp.EncoderBuffer) { + w.WriteBytes(n) +} + +func (n ValueNode) encode(w rlp.EncoderBuffer) { + w.WriteBytes(n) +} + +func (n rawNode) encode(w rlp.EncoderBuffer) { + w.Write(n) +} + +func (n rawShortNode) encode(w rlp.EncoderBuffer) { + panic("this should never end up in a live trie") +} + +func (n rawFullNode) encode(w rlp.EncoderBuffer) { + panic("this should never end up in a live trie") +} diff --git a/trie/preimages.go b/trie/preimages.go new file mode 100644 index 0000000000..760f2290f4 --- /dev/null +++ b/trie/preimages.go @@ -0,0 +1,94 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "sync" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/ethdb" +) + +// preimageStore is the store for caching preimages of node key. +type preimageStore struct { + lock sync.RWMutex + disk ethdb.Database + preimages map[common.Hash][]byte // Preimages of nodes from the secure trie + preimagesSize common.StorageSize // Storage size of the preimages cache +} + +// newPreimageStore initializes the store for caching preimages. +func newPreimageStore(disk ethdb.Database) *preimageStore { + return &preimageStore{ + disk: disk, + preimages: make(map[common.Hash][]byte), + } +} + +// insertPreimage writes a new trie node pre-image to the memory database if it's +// yet unknown. The method will NOT make a copy of the slice, only use if the +// preimage will NOT be changed later on. +func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) { + store.lock.Lock() + defer store.lock.Unlock() + + for hash, preimage := range preimages { + if _, ok := store.preimages[hash]; ok { + continue + } + store.preimages[hash] = preimage + store.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) + } +} + +// preimage retrieves a cached trie node pre-image from memory. If it cannot be +// found cached, the method queries the persistent database for the content. +func (store *preimageStore) preimage(hash common.Hash) []byte { + store.lock.RLock() + preimage := store.preimages[hash] + store.lock.RUnlock() + + if preimage != nil { + return preimage + } + return rawdb.ReadPreimage(store.disk, hash) +} + +// commit flushes the cached preimages into the disk. +func (store *preimageStore) commit(force bool) error { + store.lock.Lock() + defer store.lock.Unlock() + + if store.preimagesSize <= 4*1024*1024 && !force { + return nil + } + if err := rawdb.WritePreimages(store.disk, 0, store.preimages); err != nil { + return err + } + + store.preimages, store.preimagesSize = make(map[common.Hash][]byte), 0 + return nil +} + +// size returns the current storage size of accumulated preimages. +func (store *preimageStore) size() common.StorageSize { + store.lock.RLock() + defer store.lock.RUnlock() + + return store.preimagesSize +} diff --git a/trie/proof.go b/trie/proof.go index 9e4082a27e..28320e8a06 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -22,8 +22,8 @@ import ( "fmt" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/ethdb" - "github.com/tomochain/tomochain/ethdb/memorydb" "github.com/tomochain/tomochain/log" "github.com/tomochain/tomochain/rlp" ) @@ -395,11 +395,11 @@ func hasRightElement(node Node, key []byte) bool { // Expect the normal case, this function can also be used to verify the following // range proofs(note this function doesn't accept zero element proof): // -// - All elements proof. In this case the left and right proof can be nil, but the -// range should be all the leaves in the trie. +// - All elements proof. In this case the left and right proof can be nil, but the +// range should be all the leaves in the trie. // -// - One element proof. In this case no matter the left edge proof is a non-existent -// proof or not, we can always verify the correctness of the proof. +// - One element proof. In this case no matter the left edge proof is a non-existent +// proof or not, we can always verify the correctness of the proof. // // Except returning the error to indicate the proof is valid or not, the function will // also return a flag to indicate whether there exists more accounts/slots in the trie. @@ -419,15 +419,12 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu // Special case, there is no edge proof at all. The given range is expected // to be the whole leaf-set in the trie. if firstProof == nil && lastProof == nil { - emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New())) - if err != nil { - return err, false - } + tr := NewStackTrie(nil) for index, key := range keys { - emptytrie.TryUpdate(key, values[index]) + tr.Update(key, values[index]) } - if emptytrie.Hash() != rootHash { - return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash()), false + if have, want := tr.Hash(), rootHash; have != want { + return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()), false } return nil, false // no more element. } @@ -464,9 +461,10 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu } // Rebuild the trie with the leave stream, the shape of trie // should be same with the original one. - newtrie := &Trie{root: root, Db: NewDatabase(memorydb.New())} + + newtrie := &Trie{root: root, Db: NewDatabase(rawdb.NewMemoryDatabase())} for index, key := range keys { - newtrie.TryUpdate(key, values[index]) + newtrie.Update(key, values[index]) } if newtrie.Hash() != rootHash { return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash()), false diff --git a/trie/secure_trie.go b/trie/secure_trie.go index f62d3d06de..cbffd559e3 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -17,10 +17,9 @@ package trie import ( - "fmt" - "github.com/tomochain/tomochain/common" - "github.com/tomochain/tomochain/log" + "github.com/tomochain/tomochain/core/types" + "github.com/tomochain/tomochain/rlp" ) // SecureTrie wraps a trie with key hashing. In a secure trie, all @@ -35,6 +34,7 @@ import ( // SecureTrie is not safe for concurrent use. type SecureTrie struct { trie Trie + preimages *preimageStore hashKeyBuf [common.HashLength]byte secKeyCache map[string][]byte secKeyCacheOwner *SecureTrie // Pointer to self, replace the key Cache on mismatch @@ -50,7 +50,7 @@ type SecureTrie struct { // Accessing the trie loads nodes from the database or Node pool on demand. // Loaded nodes are kept around until their 'Cache generation' expires. // A new Cache generation is created by each call to Commit. -// cachelimit sets the number of past Cache generations to keep. +// cache limit sets the number of past Cache generations to keep. func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { if db == nil { panic("trie.NewSecure called without a database") @@ -59,49 +59,83 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { if err != nil { return nil, err } - return &SecureTrie{trie: *trie}, nil + return &SecureTrie{trie: *trie, preimages: db.preimages}, nil } -// Get returns the value for key stored in the trie. +// MustGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -func (t *SecureTrie) Get(key []byte) []byte { - res, err := t.TryGet(key) - if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// +// This function will omit any encountered error but just +// print out an error message. +func (t *SecureTrie) MustGet(key []byte) []byte { + return t.trie.MustGet(t.hashKey(key)) +} + +// GetStorage attempts to retrieve a storage slot with provided account address +// and slot key. The value bytes must not be modified by the caller. +// If the specified storage slot is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { + enc, err := t.trie.Get(t.hashKey(key)) + if err != nil || len(enc) == 0 { + return nil, err } - return res + _, content, _, err := rlp.Split(enc) + return content, err } -// TryGet returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { - return t.trie.TryGet(t.hashKey(key)) +// GetAccount attempts to retrieve an account with provided account address. +// If the specified account is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) GetAccount(address common.Address) (*types.StateAccount, error) { + res, err := t.trie.Get(t.hashKey(address.Bytes())) + if res == nil || err != nil { + return nil, err + } + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err } -// Update associates key with value in the trie. Subsequent calls to +// GetAccountByHash does the same thing as GetAccount, however it expects an +// account hash that is the hash of address. This constitutes an abstraction +// leak, since the client code needs to know the key format. +func (t *SecureTrie) GetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { + res, err := t.trie.Get(addrHash.Bytes()) + if res == nil || err != nil { + return nil, err + } + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err +} + +// MustUpdate associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. -func (t *SecureTrie) Update(key, value []byte) { - if err := t.TryUpdate(key, value); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - } +// +// This function will omit any encountered error but just print out an +// error message. +func (t *SecureTrie) MustUpdate(key, value []byte) { + hk := t.hashKey(key) + t.trie.MustUpdate(hk, value) + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) } -// TryUpdate associates key with value in the trie. Subsequent calls to +// UpdateStorage associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. // -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *SecureTrie) TryUpdate(key, value []byte) error { +// If a node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) UpdateStorage(_ common.Address, key, value []byte) error { hk := t.hashKey(key) - err := t.trie.TryUpdate(hk, value) + err := t.trie.Update(hk, value) if err != nil { return err } @@ -109,19 +143,47 @@ func (t *SecureTrie) TryUpdate(key, value []byte) error { return nil } -// Delete removes any existing value for key from the trie. -func (t *SecureTrie) Delete(key []byte) { - if err := t.TryDelete(key); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// UpdateAccount will abstract the write of an account to the secure trie. + +func (t *SecureTrie) UpdateAccount(address common.Address, acc *types.StateAccount) error { + hk := t.hashKey(address.Bytes()) + data, err := rlp.EncodeToBytes(acc) + if err != nil { + return err } + if err := t.trie.Update(hk, data); err != nil { + return err + } + t.getSecKeyCache()[string(hk)] = address.Bytes() + return nil +} + +func (t *SecureTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { + return nil +} + +// MustDelete removes any existing value for key from the trie. This function +// will omit any encountered error but just print out an error message. +func (t *SecureTrie) MustDelete(key []byte) { + hk := t.hashKey(key) + delete(t.getSecKeyCache(), string(hk)) + t.trie.MustDelete(hk) } -// TryDelete removes any existing value for key from the trie. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *SecureTrie) TryDelete(key []byte) error { +// DeleteStorage removes any existing storage slot from the trie. +// If the specified trie node is not in the trie, nothing will be changed. +// If a node is not found in the database, a MissingNodeError is returned. +func (t *SecureTrie) DeleteStorage(_ common.Address, key []byte) error { hk := t.hashKey(key) delete(t.getSecKeyCache(), string(hk)) - return t.trie.TryDelete(hk) + return t.trie.Delete(hk) +} + +// DeleteAccount abstracts an account deletion from the trie. +func (t *SecureTrie) DeleteAccount(address common.Address) error { + hk := t.hashKey(address.Bytes()) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.Delete(hk) } // GetKey returns the sha3 Preimage of a hashed key that was @@ -130,8 +192,10 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } - key, _ := t.trie.Db.Preimage(common.BytesToHash(shaKey)) - return key + if t.preimages == nil { + return nil + } + return t.preimages.preimage(common.BytesToHash(shaKey)) } // Commit writes all nodes and the secure hash pre-images to the trie's database. @@ -142,12 +206,15 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) { // Write all the pre-images to the actual disk database if len(t.getSecKeyCache()) > 0 { - t.trie.Db.Lock.Lock() - for hk, key := range t.secKeyCache { - t.trie.Db.InsertPreimage(common.BytesToHash([]byte(hk)), key) + if t.preimages != nil { + t.trie.Db.Lock.Lock() + preimages := make(map[common.Hash][]byte) + for hk, key := range t.secKeyCache { + preimages[common.BytesToHash([]byte(hk))] = key + } + t.preimages.insertPreimage(preimages) + t.trie.Db.Lock.Unlock() } - t.trie.Db.Lock.Unlock() - t.secKeyCache = make(map[string][]byte) } // Commit the trie to its intermediate Node database @@ -162,8 +229,11 @@ func (t *SecureTrie) Hash() common.Hash { // Copy returns a copy of SecureTrie. func (t *SecureTrie) Copy() *SecureTrie { - cpy := *t - return &cpy + return &SecureTrie{ + trie: *t.trie.Copy(), + preimages: t.preimages, + secKeyCache: t.secKeyCache, + } } // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index a015ffcff6..bc17b2ca40 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -23,19 +23,19 @@ import ( "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" - "github.com/tomochain/tomochain/ethdb/memorydb" ) func newEmptySecure() *SecureTrie { - trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New())) + trie, _ := NewSecure(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) return trie } // makeTestSecureTrie creates a large enough secure trie for testing. func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie, _ := NewSecure(common.Hash{}, triedb) // Fill it with some arbitrary data @@ -44,17 +44,17 @@ func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) // Add some other data to inflate the trie for j := byte(3); j < 13; j++ { key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} content[string(key)] = val - trie.Update(key, val) + trie.MustUpdate(key, val) } } trie.Commit(nil) @@ -77,9 +77,9 @@ func TestSecureDelete(t *testing.T) { } for _, val := range vals { if val.v != "" { - trie.Update([]byte(val.k), []byte(val.v)) + trie.MustUpdate([]byte(val.k), []byte(val.v)) } else { - trie.Delete([]byte(val.k)) + trie.MustDelete([]byte(val.k)) } } hash := trie.Hash() @@ -91,13 +91,13 @@ func TestSecureDelete(t *testing.T) { func TestSecureGetKey(t *testing.T) { trie := newEmptySecure() - trie.Update([]byte("foo"), []byte("bar")) + trie.MustUpdate([]byte("foo"), []byte("bar")) key := []byte("foo") value := []byte("bar") seckey := crypto.Keccak256(key) - if !bytes.Equal(trie.Get(key), value) { + if !bytes.Equal(trie.MustGet(key), value) { t.Errorf("Get did not return bar") } if k := trie.GetKey(seckey); !bytes.Equal(k, key) { @@ -125,15 +125,15 @@ func TestSecureTrieConcurrency(t *testing.T) { for j := byte(0); j < 255; j++ { // Map the same data under multiple keys key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) // Add some other data to inflate the trie for k := byte(3); k < 13; k++ { key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} - tries[index].Update(key, val) + tries[index].MustUpdate(key, val) } } tries[index].Commit(nil) diff --git a/trie/stacktrie.go b/trie/stacktrie.go new file mode 100644 index 0000000000..48417e556c --- /dev/null +++ b/trie/stacktrie.go @@ -0,0 +1,533 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bufio" + "bytes" + "encoding/gob" + "errors" + "io" + "sync" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/log" +) + +var ErrCommitDisabled = errors.New("no database for committing") + +var stPool = sync.Pool{ + New: func() interface{} { + return NewStackTrie(nil) + }, +} + +// NodeWriteFunc is used to provide all information of a dirty node for committing +// so that callers can flush nodes into database with desired scheme. +type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob []byte) + +func stackTrieFromPool(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { + st := stPool.Get().(*StackTrie) + st.owner = owner + st.writeFn = writeFn + return st +} + +func returnToPool(st *StackTrie) { + st.Reset() + stPool.Put(st) +} + +// StackTrie is a trie implementation that expects keys to be inserted +// in order. Once it determines that a subtree will no longer be inserted +// into, it will hash it and free up the memory it uses. +type StackTrie struct { + owner common.Hash // the owner of the trie + nodeType uint8 // node type (as in branch, ext, leaf) + val []byte // value contained by this node if it's a leaf + key []byte // key chunk covered by this (leaf|ext) node + children [16]*StackTrie // list of children (for branch and exts) + writeFn NodeWriteFunc // function for committing nodes, can be nil +} + +// NewStackTrie allocates and initializes an empty trie. +func NewStackTrie(writeFn NodeWriteFunc) *StackTrie { + return &StackTrie{ + nodeType: emptyNode, + writeFn: writeFn, + } +} + +// NewStackTrieWithOwner allocates and initializes an empty trie, but with +// the additional owner field. +func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie { + return &StackTrie{ + owner: owner, + nodeType: emptyNode, + writeFn: writeFn, + } +} + +// NewFromBinary initialises a serialized stacktrie with the given db. +func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { + var st StackTrie + if err := st.UnmarshalBinary(data); err != nil { + return nil, err + } + // If a database is used, we need to recursively add it to every child + if writeFn != nil { + st.setWriter(writeFn) + } + return &st, nil +} + +// MarshalBinary implements encoding.BinaryMarshaler +func (st *StackTrie) MarshalBinary() (data []byte, err error) { + var ( + b bytes.Buffer + w = bufio.NewWriter(&b) + ) + if err := gob.NewEncoder(w).Encode(struct { + Owner common.Hash + NodeType uint8 + Val []byte + Key []byte + }{ + st.owner, + st.nodeType, + st.val, + st.key, + }); err != nil { + return nil, err + } + for _, child := range st.children { + if child == nil { + w.WriteByte(0) + continue + } + w.WriteByte(1) + if childData, err := child.MarshalBinary(); err != nil { + return nil, err + } else { + w.Write(childData) + } + } + w.Flush() + return b.Bytes(), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (st *StackTrie) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + return st.unmarshalBinary(r) +} + +func (st *StackTrie) unmarshalBinary(r io.Reader) error { + var dec struct { + Owner common.Hash + NodeType uint8 + Val []byte + Key []byte + } + if err := gob.NewDecoder(r).Decode(&dec); err != nil { + return err + } + st.owner = dec.Owner + st.nodeType = dec.NodeType + st.val = dec.Val + st.key = dec.Key + + var hasChild = make([]byte, 1) + for i := range st.children { + if _, err := r.Read(hasChild); err != nil { + return err + } else if hasChild[0] == 0 { + continue + } + var child StackTrie + if err := child.unmarshalBinary(r); err != nil { + return err + } + st.children[i] = &child + } + return nil +} + +func (st *StackTrie) setWriter(writeFn NodeWriteFunc) { + st.writeFn = writeFn + for _, child := range st.children { + if child != nil { + child.setWriter(writeFn) + } + } +} + +func newLeaf(owner common.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) + st.nodeType = leafNode + st.key = append(st.key, key...) + st.val = val + return st +} + +func newExt(owner common.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie { + st := stackTrieFromPool(writeFn, owner) + st.nodeType = extNode + st.key = append(st.key, key...) + st.children[0] = child + return st +} + +// List all values that StackTrie#nodeType can hold +const ( + emptyNode = iota + branchNode + extNode + leafNode + hashedNode +) + +// Update inserts a (key, value) pair into the stack trie. +func (st *StackTrie) Update(key, value []byte) error { + k := keybytesToHex(key) + if len(value) == 0 { + panic("deletion not supported") + } + st.insert(k[:len(k)-1], value, nil) + return nil +} + +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. +func (st *StackTrie) MustUpdate(key, value []byte) { + if err := st.Update(key, value); err != nil { + log.Error("Unhandled trie error in StackTrie.Update", "err", err) + } +} + +func (st *StackTrie) Reset() { + st.owner = common.Hash{} + st.writeFn = nil + st.key = st.key[:0] + st.val = nil + for i := range st.children { + st.children[i] = nil + } + st.nodeType = emptyNode +} + +// Helper function that, given a full key, determines the index +// at which the chunk pointed by st.keyOffset is different from +// the same chunk in the full key. +func (st *StackTrie) getDiffIndex(key []byte) int { + for idx, nibble := range st.key { + if nibble != key[idx] { + return idx + } + } + return len(st.key) +} + +// Helper function to that inserts a (key, value) pair into +// the trie. +func (st *StackTrie) insert(key, value []byte, prefix []byte) { + switch st.nodeType { + case branchNode: /* Branch */ + idx := int(key[0]) + + // Unresolve elder siblings + for i := idx - 1; i >= 0; i-- { + if st.children[i] != nil { + if st.children[i].nodeType != hashedNode { + st.children[i].hash(append(prefix, byte(i))) + } + break + } + } + + // Add new child + if st.children[idx] == nil { + st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn) + } else { + st.children[idx].insert(key[1:], value, append(prefix, key[0])) + } + + case extNode: /* Ext */ + // Compare both key chunks and see where they differ + diffidx := st.getDiffIndex(key) + + // Check if chunks are identical. If so, recurse into + // the child node. Otherwise, the key has to be split + // into 1) an optional common prefix, 2) the fullnode + // representing the two differing path, and 3) a leaf + // for each of the differentiated subtrees. + if diffidx == len(st.key) { + // Ext key and key segment are identical, recurse into + // the child node. + st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...)) + return + } + // Save the original part. Depending if the break is + // at the extension's last byte or not, create an + // intermediate extension or use the extension's child + // node directly. + var n *StackTrie + if diffidx < len(st.key)-1 { + // Break on the non-last byte, insert an intermediate + // extension. The path prefix of the newly-inserted + // extension should also contain the different byte. + n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn) + n.hash(append(prefix, st.key[:diffidx+1]...)) + } else { + // Break on the last byte, no need to insert + // an extension node: reuse the current node. + // The path prefix of the original part should + // still be same. + n = st.children[0] + n.hash(append(prefix, st.key...)) + } + var p *StackTrie + if diffidx == 0 { + // the break is on the first byte, so + // the current node is converted into + // a branch node. + st.children[0] = nil + p = st + st.nodeType = branchNode + } else { + // the common prefix is at least one byte + // long, insert a new intermediate branch + // node. + st.children[0] = stackTrieFromPool(st.writeFn, st.owner) + st.children[0].nodeType = branchNode + p = st.children[0] + } + // Create a leaf for the inserted part + o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + + // Insert both child leaves where they belong: + origIdx := st.key[diffidx] + newIdx := key[diffidx] + p.children[origIdx] = n + p.children[newIdx] = o + st.key = st.key[:diffidx] + + case leafNode: /* Leaf */ + // Compare both key chunks and see where they differ + diffidx := st.getDiffIndex(key) + + // Overwriting a key isn't supported, which means that + // the current leaf is expected to be split into 1) an + // optional extension for the common prefix of these 2 + // keys, 2) a fullnode selecting the path on which the + // keys differ, and 3) one leaf for the differentiated + // component of each key. + if diffidx >= len(st.key) { + panic("Trying to insert into existing key") + } + + // Check if the split occurs at the first nibble of the + // chunk. In that case, no prefix extnode is necessary. + // Otherwise, create that + var p *StackTrie + if diffidx == 0 { + // Convert current leaf into a branch + st.nodeType = branchNode + p = st + st.children[0] = nil + } else { + // Convert current node into an ext, + // and insert a child branch node. + st.nodeType = extNode + st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner) + st.children[0].nodeType = branchNode + p = st.children[0] + } + + // Create the two child leaves: one containing the original + // value and another containing the new value. The child leaf + // is hashed directly in order to free up some memory. + origIdx := st.key[diffidx] + p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn) + p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...)) + + newIdx := key[diffidx] + p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn) + + // Finally, cut off the key part that has been passed + // over to the children. + st.key = st.key[:diffidx] + st.val = nil + + case emptyNode: /* Empty */ + st.nodeType = leafNode + st.key = key + st.val = value + + case hashedNode: + panic("trying to insert into hash") + + default: + panic("invalid type") + } +} + +// hash converts st into a 'hashedNode', if possible. Possible outcomes: +// +// 1. The rlp-encoded value was >= 32 bytes: +// - Then the 32-byte `hash` will be accessible in `st.val`. +// - And the 'st.type' will be 'hashedNode' +// +// 2. The rlp-encoded value was < 32 bytes +// - Then the <32 byte rlp-encoded value will be accessible in 'st.val'. +// - And the 'st.type' will be 'hashedNode' AGAIN +// +// This method also sets 'st.type' to hashedNode, and clears 'st.key'. +func (st *StackTrie) hash(path []byte) { + h := newHasher(false) + defer returnHasherToPool(h) + + st.hashRec(h, path) +} + +func (st *StackTrie) hashRec(hasher *hasher, path []byte) { + // The switch below sets this to the RLP-encoding of this node. + var encodedNode []byte + + switch st.nodeType { + case hashedNode: + return + + case emptyNode: + st.val = emptyRoot.Bytes() + st.key = st.key[:0] + st.nodeType = hashedNode + return + + case branchNode: + var nodes FullNode + for i, child := range st.children { + if child == nil { + nodes.Children[i] = nilValueNode + continue + } + child.hashRec(hasher, append(path, byte(i))) + if len(child.val) < 32 { + nodes.Children[i] = rawNode(child.val) + } else { + nodes.Children[i] = HashNode(child.val) + } + + // Release child back to pool. + st.children[i] = nil + returnToPool(child) + } + + nodes.encode(hasher.encbuf) + encodedNode = hasher.encodedBytes() + + case extNode: + st.children[0].hashRec(hasher, append(path, st.key...)) + + n := ShortNode{Key: hexToCompact(st.key)} + if len(st.children[0].val) < 32 { + n.Val = rawNode(st.children[0].val) + } else { + n.Val = HashNode(st.children[0].val) + } + + n.encode(hasher.encbuf) + encodedNode = hasher.encodedBytes() + + // Release child back to pool. + returnToPool(st.children[0]) + st.children[0] = nil + + case leafNode: + st.key = append(st.key, byte(16)) + n := ShortNode{Key: hexToCompact(st.key), Val: ValueNode(st.val)} + + n.encode(hasher.encbuf) + encodedNode = hasher.encodedBytes() + + default: + panic("invalid node type") + } + + st.nodeType = hashedNode + st.key = st.key[:0] + if len(encodedNode) < 32 { + st.val = common.CopyBytes(encodedNode) + return + } + + // Write the hash to the 'val'. We allocate a new val here to not mutate + // input values + st.val = hasher.hashData(encodedNode) + if st.writeFn != nil { + st.writeFn(st.owner, path, common.BytesToHash(st.val), encodedNode) + } +} + +// Hash returns the hash of the current node. +func (st *StackTrie) Hash() (h common.Hash) { + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher, nil) + if len(st.val) == 32 { + copy(h[:], st.val) + return h + } + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed, and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + return h +} + +// Commit will firstly hash the entire trie if it's still not hashed +// and then commit all nodes to the associated database. Actually most +// of the trie nodes MAY have been committed already. The main purpose +// here is to commit the root node. +// +// The associated database is expected, otherwise the whole commit +// functionality should be disabled. +func (st *StackTrie) Commit() (h common.Hash, err error) { + if st.writeFn == nil { + return common.Hash{}, ErrCommitDisabled + } + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + st.hashRec(hasher, nil) + if len(st.val) == 32 { + copy(h[:], st.val) + return h, nil + } + // If the node's RLP isn't 32 bytes long, the node will not + // be hashed (and committed), and instead contain the rlp-encoding of the + // node. For the top level node, we need to force the hashing+commit. + hasher.sha.Reset() + hasher.sha.Write(st.val) + hasher.sha.Read(h[:]) + + st.writeFn(st.owner, nil, h, st.val) + return h, nil +} diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go new file mode 100644 index 0000000000..dd5206c87c --- /dev/null +++ b/trie/stacktrie_test.go @@ -0,0 +1,413 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bytes" + "math/big" + "testing" + + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" + "github.com/tomochain/tomochain/crypto" +) + +func TestStackTrieInsertAndHash(t *testing.T) { + type KeyValueHash struct { + K string // Hex string for key. + V string // Value, directly converted to bytes. + H string // Expected root hash after insert of (K, V) to an existing trie. + } + tests := [][]KeyValueHash{ + { // {0:0, 7:0, f:0} + {"00", "v_______________________0___0", "5cb26357b95bb9af08475be00243ceb68ade0b66b5cd816b0c18a18c612d2d21"}, + {"70", "v_______________________0___1", "8ff64309574f7a437a7ad1628e690eb7663cfde10676f8a904a8c8291dbc1603"}, + {"f0", "v_______________________0___2", "9e3a01bd8d43efb8e9d4b5506648150b8e3ed1caea596f84ee28e01a72635470"}, + }, + { // {1:0cc, e:{1:fc, e:fc}} + {"10cc", "v_______________________1___0", "233e9b257843f3dfdb1cce6676cdaf9e595ac96ee1b55031434d852bc7ac9185"}, + {"e1fc", "v_______________________1___1", "39c5e908ae83d0c78520c7c7bda0b3782daf594700e44546e93def8f049cca95"}, + {"eefc", "v_______________________1___2", "d789567559fd76fe5b7d9cc42f3750f942502ac1c7f2a466e2f690ec4b6c2a7c"}, + }, + { // {b:{a:ac, b:ac}, d:acc} + {"baac", "v_______________________2___0", "8be1c86ba7ec4c61e14c1a9b75055e0464c2633ae66a055a24e75450156a5d42"}, + {"bbac", "v_______________________2___1", "8495159b9895a7d88d973171d737c0aace6fe6ac02a4769fff1bc43bcccce4cc"}, + {"dacc", "v_______________________2___2", "9bcfc5b220a27328deb9dc6ee2e3d46c9ebc9c69e78acda1fa2c7040602c63ca"}, + }, + { // {0:0cccc, 2:456{0:0, 2:2} + {"00cccc", "v_______________________3___0", "e57dc2785b99ce9205080cb41b32ebea7ac3e158952b44c87d186e6d190a6530"}, + {"245600", "v_______________________3___1", "0335354adbd360a45c1871a842452287721b64b4234dfe08760b243523c998db"}, + {"245622", "v_______________________3___2", "9e6832db0dca2b5cf81c0e0727bfde6afc39d5de33e5720bccacc183c162104e"}, + }, + { // {1:4567{1:1c, 3:3c}, 3:0cccccc} + {"1456711c", "v_______________________4___0", "f2389e78d98fed99f3e63d6d1623c1d4d9e8c91cb1d585de81fbc7c0e60d3529"}, + {"1456733c", "v_______________________4___1", "101189b3fab852be97a0120c03d95eefcf984d3ed639f2328527de6def55a9c0"}, + {"30cccccc", "v_______________________4___2", "3780ce111f98d15751dfde1eb21080efc7d3914b429e5c84c64db637c55405b3"}, + }, + { // 8800{1:f, 2:e, 3:d} + {"88001f", "v_______________________5___0", "e817db50d84f341d443c6f6593cafda093fc85e773a762421d47daa6ac993bd5"}, + {"88002e", "v_______________________5___1", "d6e3e6047bdc110edd296a4d63c030aec451bee9d8075bc5a198eee8cda34f68"}, + {"88003d", "v_______________________5___2", "b6bdf8298c703342188e5f7f84921a402042d0e5fb059969dd53a6b6b1fb989e"}, + }, + { // 0{1:fc, 2:ec, 4:dc} + {"01fc", "v_______________________6___0", "693268f2ca80d32b015f61cd2c4dba5a47a6b52a14c34f8e6945fad684e7a0d5"}, + {"02ec", "v_______________________6___1", "e24ddd44469310c2b785a2044618874bf486d2f7822603a9b8dce58d6524d5de"}, + {"04dc", "v_______________________6___2", "33fc259629187bbe54b92f82f0cd8083b91a12e41a9456b84fc155321e334db7"}, + }, + { // f{0:fccc, f:ff{0:f, f:f}} + {"f0fccc", "v_______________________7___0", "b0966b5aa469a3e292bc5fcfa6c396ae7a657255eef552ea7e12f996de795b90"}, + {"ffff0f", "v_______________________7___1", "3b1ca154ec2a3d96d8d77bddef0abfe40a53a64eb03cecf78da9ec43799fa3d0"}, + {"ffffff", "v_______________________7___2", "e75463041f1be8252781be0ace579a44ea4387bf5b2739f4607af676f7719678"}, + }, + { // ff{0:f{0:f, f:f}, f:fcc} + {"ff0f0f", "v_______________________8___0", "0928af9b14718ec8262ab89df430f1e5fbf66fac0fed037aff2b6767ae8c8684"}, + {"ff0fff", "v_______________________8___1", "d870f4d3ce26b0bf86912810a1960693630c20a48ba56be0ad04bc3e9ddb01e6"}, + {"ffffcc", "v_______________________8___2", "4239f10dd9d9915ecf2e047d6a576bdc1733ed77a30830f1bf29deaf7d8e966f"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"123f", "x___________________________2", "1164d7299964e74ac40d761f9189b2a3987fae959800d0f7e29d3aaf3eae9e15"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"124a", "x___________________________2", "661a96a669869d76b7231380da0649d013301425fbea9d5c5fae6405aa31cfce"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"13aa", "x___________________________2", "6590120e1fd3ffd1a90e8de5bb10750b61079bb0776cca4414dd79a24e4d4356"}, + }, + { + {"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"}, + {"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"}, + {"2aaa", "x___________________________2", "f869b40e0c55eace1918332ef91563616fbf0755e2b946119679f7ef8e44b514"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"1234fa", "x___________________________2", "4f4e368ab367090d5bc3dbf25f7729f8bd60df84de309b4633a6b69ab66142c0"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"1235aa", "x___________________________2", "21840121d11a91ac8bbad9a5d06af902a5c8d56a47b85600ba813814b7bfcb9b"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"124aaa", "x___________________________2", "ea4040ddf6ae3fbd1524bdec19c0ab1581015996262006632027fa5cf21e441e"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"13aaaa", "x___________________________2", "e4beb66c67e44f2dd8ba36036e45a44ff68f8d52942472b1911a45f886a34507"}, + }, + { + {"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"}, + {"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"}, + {"2aaaaa", "x___________________________2", "5f5989b820ff5d76b7d49e77bb64f26602294f6c42a1a3becc669cd9e0dc8ec9"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"1234fa", "x___________________________3", "65bb3aafea8121111d693ffe34881c14d27b128fd113fa120961f251fe28428d"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"1235aa", "x___________________________3", "f670e4d2547c533c5f21e0045442e2ecb733f347ad6d29ef36e0f5ba31bb11a8"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"124aaa", "x___________________________3", "c17464123050a9a6f29b5574bb2f92f6d305c1794976b475b7fb0316b6335598"}, + }, + { + {"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"}, + {"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"}, + {"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"}, + {"13aaaa", "x___________________________3", "aa8301be8cb52ea5cd249f5feb79fb4315ee8de2140c604033f4b3fff78f0105"}, + }, + { + {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"}, + {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"}, + {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"}, + {"123f", "x___________________________3", "80f7bad1893ca57e3443bb3305a517723a74d3ba831bcaca22a170645eb7aafb"}, + }, + { + {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"}, + {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"}, + {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"}, + {"124a", "x___________________________3", "383bc1bb4f019e6bc4da3751509ea709b58dd1ac46081670834bae072f3e9557"}, + }, + { + {"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"}, + {"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"}, + {"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"}, + {"13aa", "x___________________________3", "ff0dc70ce2e5db90ee42a4c2ad12139596b890e90eb4e16526ab38fa465b35cf"}, + }, + } + st := NewStackTrie(nil) + for i, test := range tests { + // The StackTrie does not allow Insert(), Hash(), Insert(), ... + // so we will create new trie for every sequence length of inserts. + for l := 1; l <= len(test); l++ { + st.Reset() + for j := 0; j < l; j++ { + kv := &test[j] + if err := st.Update(common.FromHex(kv.K), []byte(kv.V)); err != nil { + t.Fatal(err) + } + } + expected := common.HexToHash(test[l-1].H) + if h := st.Hash(); h != expected { + t.Errorf("%d(%d): root hash mismatch: %x, expected %x", i, l, h, expected) + } + } + } +} + +func TestSizeBug(t *testing.T) { + st := NewStackTrie(nil) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") + value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") + + nt.Update(leaf, value) + st.Update(leaf, value) + + if nt.Hash() != st.Hash() { + t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) + } +} + +func TestEmptyBug(t *testing.T) { + st := NewStackTrie(nil) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") + //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") + kvs := []struct { + K string + V string + }{ + {K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "9496f4ec2bf9dab484cac6be589e8417d84781be08"}, + {K: "40edb63a35fcf86c08022722aa3287cdd36440d671b4918131b2514795fefa9c", V: "01"}, + {K: "b10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", V: "947a30f7736e48d6599356464ba4c150d8da0302ff"}, + {K: "c2575a0e9e593c00f959f8c92f12db2869c3395a3b0502d05e2516446f71f85b", V: "02"}, + } + + for _, kv := range kvs { + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + } + + if nt.Hash() != st.Hash() { + t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) + } +} + +func TestValLength56(t *testing.T) { + st := NewStackTrie(nil) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") + //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") + kvs := []struct { + K string + V string + }{ + {K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"}, + } + + for _, kv := range kvs { + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + } + + if nt.Hash() != st.Hash() { + t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) + } +} + +// TestUpdateSmallNodes tests a case where the leaves are small (both key and value), +// which causes a lot of node-within-node. This case was found via fuzzing. +func TestUpdateSmallNodes(t *testing.T) { + st := NewStackTrie(nil) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + kvs := []struct { + K string + V string + }{ + {"63303030", "3041"}, // stacktrie.Update + {"65", "3000"}, // stacktrie.Update + } + for _, kv := range kvs { + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + } + if nt.Hash() != st.Hash() { + t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) + } +} + +// TestUpdateVariableKeys contains a case which stacktrie fails: when keys of different +// sizes are used, and the second one has the same prefix as the first, then the +// stacktrie fails, since it's unable to 'expand' on an already added leaf. +// For all practical purposes, this is fine, since keys are fixed-size length +// in account and storage tries. +// +// The test is marked as 'skipped', and exists just to have the behaviour documented. +// This case was found via fuzzing. +func TestUpdateVariableKeys(t *testing.T) { + t.SkipNow() + st := NewStackTrie(nil) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + kvs := []struct { + K string + V string + }{ + {"0x33303534636532393561313031676174", "303030"}, + {"0x3330353463653239356131303167617430", "313131"}, + } + for _, kv := range kvs { + nt.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + st.Update(common.FromHex(kv.K), common.FromHex(kv.V)) + } + if nt.Hash() != st.Hash() { + t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) + } +} + +// TestStacktrieNotModifyValues checks that inserting blobs of data into the +// stacktrie does not mutate the blobs +func TestStacktrieNotModifyValues(t *testing.T) { + st := NewStackTrie(nil) + { // Test a very small trie + // Give it the value as a slice with large backing alloc, + // so if the stacktrie tries to append, it won't have to realloc + value := make([]byte, 1, 100) + value[0] = 0x2 + want := common.CopyBytes(value) + st.Update([]byte{0x01}, value) + st.Hash() + if have := value; !bytes.Equal(have, want) { + t.Fatalf("tiny trie: have %#x want %#x", have, want) + } + st = NewStackTrie(nil) + } + // Test with a larger trie + keyB := big.NewInt(1) + keyDelta := big.NewInt(1) + var vals [][]byte + getValue := func(i int) []byte { + if i%2 == 0 { // large + return crypto.Keccak256(big.NewInt(int64(i)).Bytes()) + } else { //small + return big.NewInt(int64(i)).Bytes() + } + } + for i := 0; i < 1000; i++ { + key := common.BigToHash(keyB) + value := getValue(i) + st.Update(key.Bytes(), value) + vals = append(vals, value) + keyB = keyB.Add(keyB, keyDelta) + keyDelta.Add(keyDelta, common.Big1) + } + st.Hash() + for i := 0; i < 1000; i++ { + want := getValue(i) + + have := vals[i] + if !bytes.Equal(have, want) { + t.Fatalf("item %d, have %#x want %#x", i, have, want) + } + } +} + +// TestStacktrieSerialization tests that the stacktrie works well if we +// serialize/unserialize it a lot +func TestStacktrieSerialization(t *testing.T) { + var ( + st = NewStackTrie(nil) + keyB = big.NewInt(1) + keyDelta = big.NewInt(1) + vals [][]byte + keys [][]byte + ) + nt, err := New(emptyRoot, NewDatabase(rawdb.NewMemoryDatabase())) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + getValue := func(i int) []byte { + if i%2 == 0 { // large + return crypto.Keccak256(big.NewInt(int64(i)).Bytes()) + } else { //small + return big.NewInt(int64(i)).Bytes() + } + } + for i := 0; i < 10; i++ { + vals = append(vals, getValue(i)) + keys = append(keys, common.BigToHash(keyB).Bytes()) + keyB = keyB.Add(keyB, keyDelta) + keyDelta.Add(keyDelta, common.Big1) + } + for i, k := range keys { + nt.Update(k, common.CopyBytes(vals[i])) + } + + for i, k := range keys { + blob, err := st.MarshalBinary() + if err != nil { + t.Fatal(err) + } + newSt, err := NewFromBinary(blob, nil) + if err != nil { + t.Fatal(err) + } + st = newSt + st.Update(k, common.CopyBytes(vals[i])) + } + if have, want := st.Hash(), nt.Hash(); have != want { + t.Fatalf("have %#x want %#x", have, want) + } +} diff --git a/trie/sync_test.go b/trie/sync_test.go index b7627054ae..25baa5c67c 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -21,13 +21,14 @@ import ( "testing" "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/ethdb/memorydb" ) // makeTestTrie create a sample test trie to test Node-wise reconstruction. func makeTestTrie() (*Database, *Trie, map[string][]byte) { // Create an empty trie - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) trie, _ := New(common.Hash{}, triedb) // Fill it with some arbitrary data @@ -67,7 +68,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri t.Fatalf("inconsistent trie at %x: %v", root, err) } for key, val := range content { - if have := trie.Get([]byte(key)); !bytes.Equal(have, val) { + if have, _ := trie.Get([]byte(key)); !bytes.Equal(have, val) { t.Errorf("entry %x: content mismatch: have %x, want %x", key, have, val) } } @@ -88,8 +89,8 @@ func checkTrieConsistency(db *Database, root common.Hash) error { // Tests that an empty trie is not scheduled for syncing. func TestEmptySync(t *testing.T) { - dbA := NewDatabase(memorydb.New()) - dbB := NewDatabase(memorydb.New()) + dbA := NewDatabase(rawdb.NewMemoryDatabase()) + dbB := NewDatabase(rawdb.NewMemoryDatabase()) emptyA, _ := New(common.Hash{}, dbA) emptyB, _ := New(emptyRoot, dbB) @@ -110,7 +111,7 @@ func testIterativeSync(t *testing.T, count int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -145,7 +146,7 @@ func TestIterativeDelayedSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -185,7 +186,7 @@ func testIterativeRandomSync(t *testing.T, count int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -228,7 +229,7 @@ func TestIterativeRandomDelayedSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -277,7 +278,7 @@ func TestDuplicateAvoidanceSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) @@ -319,7 +320,7 @@ func TestIncompleteSync(t *testing.T) { srcDb, srcTrie, _ := makeTestTrie() // Create a destination trie and sync with the scheduler - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) diff --git a/trie/trie.go b/trie/trie.go index 589a96186d..9df6e56559 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -82,35 +82,45 @@ func New(root common.Hash, db *Database) (*Trie, error) { return trie, nil } +// Copy returns a copy of Trie. +func (t *Trie) Copy() *Trie { + return &Trie{ + Db: t.Db, + root: t.root, + unhashed: t.unhashed, + } +} + // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // the key after the given start key. func (t *Trie) NodeIterator(start []byte) NodeIterator { return newNodeIterator(t, start) } -// Get returns the value for key stored in the trie. -// The value bytes must not be modified by the caller. -func (t *Trie) Get(key []byte) []byte { - res, err := t.TryGet(key) +// MustGet is a wrapper of Get and will omit any encountered error but just +// print out an error message. +func (t *Trie) MustGet(key []byte) []byte { + res, err := t.Get(key) if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + log.Error("Unhandled trie error in Trie.Get", "err", err) } return res } -// TryGet returns the value for key stored in the trie. +// Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryGet(key []byte) ([]byte, error) { - key = keybytesToHex(key) - value, newroot, didResolve, err := t.tryGet(t.root, key, 0) +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Get(key []byte) ([]byte, error) { + value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0) if err == nil && didResolve { t.root = newroot } return value, err } -func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode Node, didResolve bool, err error) { +func (t *Trie) get(origNode Node, key []byte, pos int) (value []byte, newnode Node, didResolve bool, err error) { switch n := (origNode).(type) { case nil: return nil, nil, false, nil @@ -121,14 +131,14 @@ func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode // key not found in trie return nil, n, false, nil } - value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) + value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { n = n.copy() n.Val = newnode } return value, n, didResolve, err case *FullNode: - value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) + value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { n = n.copy() n.Children[key[pos]] = newnode @@ -139,10 +149,10 @@ func (t *Trie) tryGet(origNode Node, key []byte, pos int) (value []byte, newnode if err != nil { return nil, n, true, err } - value, newnode, _, err := t.tryGet(child, key, pos) + value, newnode, _, err := t.get(child, key, pos) return value, newnode, true, err default: - panic(fmt.Sprintf("%T: invalid Node: %v", origNode, origNode)) + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } } @@ -310,27 +320,28 @@ func (t *Trie) tryGetBestRightKeyAndValue(origNode Node, prefix []byte) (key []b return nil, nil, nil, false, fmt.Errorf("%T: invalid Node: %v", origNode, origNode) } -// Update associates key with value in the trie. Subsequent calls to -// Get will return value. If value has length zero, any existing value -// is deleted from the trie and calls to Get will return nil. -// -// The value bytes must not be modified by the caller while they are -// stored in the trie. -func (t *Trie) Update(key, value []byte) { - if err := t.TryUpdate(key, value); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// MustUpdate is a wrapper of Update and will omit any encountered error but +// just print out an error message. +func (t *Trie) MustUpdate(key, value []byte) { + if err := t.Update(key, value); err != nil { + log.Error("Unhandled trie error in Trie.Update", "err", err) } } -// TryUpdate associates key with value in the trie. Subsequent calls to +// Update associates key with value in the trie. Subsequent calls to // Get will return value. If value has length zero, any existing value // is deleted from the trie and calls to Get will return nil. // // The value bytes must not be modified by the caller while they are // stored in the trie. // -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryUpdate(key, value []byte) error { +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Update(key, value []byte) error { + return t.update(key, value) +} + +func (t *Trie) update(key, value []byte) error { t.unhashed++ k := keybytesToHex(key) if len(value) != 0 { @@ -418,16 +429,19 @@ func (t *Trie) insert(n Node, prefix, key []byte, value Node) (bool, Node, error } } -// Delete removes any existing value for key from the trie. -func (t *Trie) Delete(key []byte) { - if err := t.TryDelete(key); err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) +// MustDelete is a wrapper of Delete and will omit any encountered error but +// just print out an error message. +func (t *Trie) MustDelete(key []byte) { + if err := t.Delete(key); err != nil { + log.Error("Unhandled trie error in Trie.Delete", "err", err) } } -// TryDelete removes any existing value for key from the trie. -// If a Node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryDelete(key []byte) error { +// Delete removes any existing value for key from the trie. +// +// If the requested node is not present in trie, no error will be returned. +// If the trie is corrupted, a MissingNodeError is returned. +func (t *Trie) Delete(key []byte) error { t.unhashed++ k := keybytesToHex(key) _, n, err := t.delete(t.root, nil, k) @@ -462,8 +476,8 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { switch child := child.(type) { case *ShortNode: // Deleting from the subtrie reduced it to another - // short Node. Merge the nodes to avoid creating a - // ShortNode{..., ShortNode{...}}. Use concat (which + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. @@ -481,10 +495,18 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { n.flags = t.newFlag() n.Children[key[0]] = nn + // Because n is a full node, it must've contained at least two children + // before the delete operation. If the new child value is non-nil, n still + // has at least two children after the deletion, and cannot be reduced to + // a short node. + if nn != nil { + return true, n, nil + } + // Reduction: // Check how many non-nil entries are left after deleting and - // reduce the full Node to a short Node if only one entry is + // reduce the full node to a short node if only one entry is // left. Since n must've contained at least two children - // before deletion (otherwise it would not be a full Node) n + // before deletion (otherwise it would not be a full node) n // can never be reduced to nil. // // When the loop is done, pos contains the index of the single @@ -503,10 +525,10 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { } if pos >= 0 { if pos != 16 { - // If the remaining entry is a short Node, it replaces + // If the remaining entry is a short node, it replaces // n and its key gets the missing nibble tacked to the // front. This avoids creating an invalid - // ShortNode{..., ShortNode{...}}. Since the entry + // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. cnode, err := t.resolve(n.Children[pos], prefix) @@ -518,7 +540,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { return true, &ShortNode{k, cnode.Val, t.newFlag()}, nil } } - // Otherwise, n is replaced by a one-nibble short Node + // Otherwise, n is replaced by a one-nibble short node // containing the child. return true, &ShortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil } @@ -533,7 +555,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { case HashNode: // We've hit a part of the trie that isn't loaded yet. Load - // the Node and delete from it. This leaves all child nodes on + // the node and delete from it. This leaves all child nodes on // the path to the value in the trie. rn, err := t.resolveHash(n, prefix) if err != nil { @@ -546,7 +568,7 @@ func (t *Trie) delete(n Node, prefix, key []byte) (bool, Node, error) { return true, nn, nil default: - panic(fmt.Sprintf("%T: invalid Node: %v (%v)", n, n, key)) + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) } } @@ -637,3 +659,9 @@ func (t *Trie) hashRoot(db *Database) (Node, Node, error) { t.unhashed = 0 return hashed, cached, nil } + +// Reset drops the referenced root node and cleans all internal state. +func (t *Trie) Reset() { + t.root = nil + t.unhashed = 0 +} diff --git a/trie/trie_test.go b/trie/trie_test.go index 8087a4a8a9..fdfcf4858b 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -29,10 +29,11 @@ import ( "testing/quick" "github.com/davecgh/go-spew/spew" + "github.com/tomochain/tomochain/common" + "github.com/tomochain/tomochain/core/rawdb" "github.com/tomochain/tomochain/crypto" "github.com/tomochain/tomochain/ethdb/leveldb" - "github.com/tomochain/tomochain/ethdb/memorydb" "github.com/tomochain/tomochain/rlp" ) @@ -43,7 +44,7 @@ func init() { // Used for testing func newEmpty() *Trie { - trie, _ := New(common.Hash{}, NewDatabase(memorydb.New())) + trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) return trie } @@ -61,13 +62,13 @@ func TestNull(t *testing.T) { key := make([]byte, 32) value := []byte("test") trie.Update(key, value) - if !bytes.Equal(trie.Get(key), value) { + if !bytes.Equal(trie.MustGet(key), value) { t.Fatal("wrong value") } } func TestMissingRoot(t *testing.T) { - trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(memorydb.New())) + trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(rawdb.NewMemoryDatabase())) if trie != nil { t.Error("New returned non-nil trie for invalid root") } @@ -80,7 +81,7 @@ func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) } func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } func testMissingNode(t *testing.T, memonly bool) { - diskdb := memorydb.New() + diskdb := rawdb.NewMemoryDatabase() triedb := NewDatabase(diskdb) trie, _ := New(common.Hash{}, triedb) @@ -92,27 +93,27 @@ func testMissingNode(t *testing.T, memonly bool) { } trie, _ = New(root, triedb) - _, err := trie.TryGet([]byte("120000")) + _, err := trie.Get([]byte("120000")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("120099")) + _, err = trie.Get([]byte("120099")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("123456")) + _, err = trie.Get([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) + err = trie.Update([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryDelete([]byte("123456")) + err = trie.Delete([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -125,27 +126,27 @@ func testMissingNode(t *testing.T, memonly bool) { } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("120000")) + _, err = trie.Get([]byte("120000")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("120099")) + _, err = trie.Get([]byte("120099")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(root, triedb) - _, err = trie.TryGet([]byte("123456")) + _, err = trie.Get([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) + err = trie.Update([]byte("120099"), []byte("zxcv")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } trie, _ = New(root, triedb) - err = trie.TryDelete([]byte("123456")) + err = trie.Delete([]byte("123456")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } @@ -403,7 +404,7 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value { } func runRandTest(rt randTest) bool { - triedb := NewDatabase(memorydb.New()) + triedb := NewDatabase(rawdb.NewMemoryDatabase()) tr, _ := New(common.Hash{}, triedb) values := make(map[string]string) // tracks content of the trie @@ -419,7 +420,7 @@ func runRandTest(rt randTest) bool { tr.Delete(step.key) delete(values, string(step.key)) case opGet: - v := tr.Get(step.key) + v := tr.MustGet(step.key) want := values[string(step.key)] if string(v) != want { rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) @@ -823,15 +824,11 @@ func tempDB() (string, *Database) { if err != nil { panic(fmt.Sprintf("can't create temporary directory: %v", err)) } - diskdb, err := leveldb.New(dir, 256, 0, "") - if err != nil { - panic(fmt.Sprintf("can't create temporary database: %v", err)) - } - return dir, NewDatabase(diskdb) + return dir, NewDatabase(rawdb.NewMemoryDatabase()) } func getString(trie *Trie, k string) []byte { - return trie.Get([]byte(k)) + return trie.MustGet([]byte(k)) } func updateString(trie *Trie, k, v string) {