diff --git a/geth-utils/gethutil/mpt/trie/stacktrie.go b/geth-utils/gethutil/mpt/trie/stacktrie.go index c5c1849402..0b214071f3 100644 --- a/geth-utils/gethutil/mpt/trie/stacktrie.go +++ b/geth-utils/gethutil/mpt/trie/stacktrie.go @@ -243,16 +243,17 @@ func (st *StackTrie) insert(key, value []byte) { break } } + // Add new child if st.children[idx] == nil { - st.children[idx] = stackTrieFromPool(st.db) - st.children[idx].keyOffset = st.keyOffset + 1 + st.children[idx] = newLeaf(st.keyOffset+1, key, value, st.db) + } else { + st.children[idx].insert(key, value) } - st.children[idx].insert(key, value) + case extNode: /* Ext */ // Compare both key chunks and see where they differ diffidx := st.getDiffIndex(key) - // Check if chunks are identical. If so, recurse into // the child node. Otherwise, the key has to be split // into 1) an optional common prefix, 2) the fullnode @@ -551,50 +552,70 @@ func (st *StackTrie) Commit() (common.Hash, error) { return common.BytesToHash(st.val), nil } -func (st *StackTrie) getNodeFromBranchRLP(branch []byte, ind byte) []byte { - start := 2 // when branch[0] == 248 - if branch[0] == 249 { - start = 3 - } - - i := 0 - insideInd := -1 - cInd := byte(0) - for { - if start+i == len(branch)-1 { // -1 because of the last 128 (branch value) - return []byte{0} - } - b := branch[start+i] - if insideInd == -1 && b == 128 { - if cInd == ind { +const RLP_SHORT_STR_FLAG = 128 +const RLP_SHORT_LIST_FLAG = 192 +const RLP_LONG_LIST_FLAG = 248 +const LEN_OF_HASH = 32 + +// Note: +// In RLP encoding, if the value is between [0x80, 0xb7] ([128, 183]), +// it means following data is a short string (0 - 55bytes). +// Which implies if the value is 128, it's an empty string. +func (st *StackTrie) getNodeFromBranchRLP(branch []byte, idx int) []byte { + + start := int(branch[0]) + start_idx := 0 + if start >= RLP_SHORT_LIST_FLAG && start < RLP_LONG_LIST_FLAG { + // In RLP encoding, length in the range of [192 248] is a short list. 
+ // In the stack trie, it usually means an extension node and the first byte is a nibble, + // and that's why we start from 2 + start_idx = 2 + } else if start >= RLP_LONG_LIST_FLAG { + // In RLP encoding, a first byte of 248 or above marks a long list. + // The RLP byte minus 247 (branch[0] - 247) is the number of bytes used to encode the payload length, + // and the payload comes right after those length bytes. + // That's why we add 2 to (branch[0] - 248) here + // e.g. [248 81 128 160 ...] + // `81` is the length of the payload and the payload starts from `128` + start_idx = start - RLP_LONG_LIST_FLAG + 2 + } + + // If the 1st node is neither 128 (empty node) nor 160, it should be a leaf + b := int(branch[start_idx]) + if b != RLP_SHORT_STR_FLAG && b != (RLP_SHORT_STR_FLAG+LEN_OF_HASH) { + return []byte{0} + } + + current_idx := 0 + for i := start_idx; i < len(branch); i++ { + b = int(branch[i]) + switch b { + case RLP_SHORT_STR_FLAG: // 128 + // if the current index is the one we're looking for, return an empty node directly + if current_idx == idx { return []byte{128} - } else { - cInd += 1 } - } else if insideInd == -1 && b != 128 { - if b == 160 { - if cInd == ind { - return branch[start+i+1 : start+i+1+32] - } - insideInd = 32 - } else { - // non-hashed node - if cInd == ind { - return branch[start+i+1 : start+i+1+int(b)-192] - } - insideInd = int(b) - 192 + current_idx++ + case RLP_SHORT_STR_FLAG + LEN_OF_HASH: // 160 + if current_idx == idx { + return branch[i+1 : i+1+LEN_OF_HASH] } - cInd += 1 - } else { - if insideInd == 1 { - insideInd = -1 - } else { - insideInd-- + // jump to the next encoded element + i += LEN_OF_HASH + current_idx++ + default: + if b >= 192 && b < 248 { + length := b - 192 + if current_idx == idx { + return branch[i+1 : i+1+length] + } + i += length + current_idx++ } } - - i++ } + + return []byte{0} } type StackProof struct { @@ -602,6 +623,14 @@ type StackProof struct { proofC [][]byte } +func (sp *StackProof) GetProofS() [][]byte { + return sp.proofS +} + +func (sp *StackProof) GetProofC() 
[][]byte { + return sp.proofC +} + func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value []byte) (StackProof, error) { proofS, err := st.GetProof(db, indexBuf) if err != nil { @@ -618,6 +647,8 @@ func (st *StackTrie) UpdateAndGetProof(db ethdb.KeyValueReader, indexBuf, value return StackProof{proofS, proofC}, nil } +// We refer to the link below for this function. +// https://github.com/ethereum/go-ethereum/blob/00905f7dc406cfb67f64cd74113777044fb886d8/core/types/hashing.go#L105-L134 func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.DerivableList) ([]StackProof, error) { valueBuf := types.EncodeBufferPool.Get().(*bytes.Buffer) defer types.EncodeBufferPool.Put(valueBuf) @@ -631,25 +662,33 @@ func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.Deri for i := 1; i < list.Len() && i <= 0x7f; i++ { indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) value := types.EncodeForDerive(list, i, valueBuf) - proof, err := st.UpdateAndGetProof(db, indexBuf, value) if err != nil { return nil, err } - proofs = append(proofs, proof) } + + // special case when index is 0 + // rlp.AppendUint64() encodes index 0 to [128] if list.Len() > 0 { indexBuf = rlp.AppendUint64(indexBuf[:0], 0) value := types.EncodeForDerive(list, 0, valueBuf) - // TODO: get proof - st.Update(indexBuf, value) + proof, err := st.UpdateAndGetProof(db, indexBuf, value) + if err != nil { + return nil, err + } + proofs = append(proofs, proof) } + for i := 0x80; i < list.Len(); i++ { indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i)) value := types.EncodeForDerive(list, i, valueBuf) - // TODO: get proof - st.Update(indexBuf, value) + proof, err := st.UpdateAndGetProof(db, indexBuf, value) + if err != nil { + return nil, err + } + proofs = append(proofs, proof) } return proofs, nil @@ -657,7 +696,6 @@ func (st *StackTrie) UpdateAndGetProofs(db ethdb.KeyValueReader, list types.Deri func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key 
[]byte) ([][]byte, error) { k := KeybytesToHex(key) - if st.nodeType == emptyNode { return [][]byte{}, nil } @@ -682,7 +720,8 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er for i := 0; i < len(k); i++ { if c.nodeType == extNode { nodes = append(nodes, c) - c = st.children[0] + c = c.children[0] + } else if c.nodeType == branchNode { nodes = append(nodes, c) c = c.children[k[i]] @@ -700,11 +739,11 @@ func (st *StackTrie) GetProof(db ethdb.KeyValueReader, key []byte) ([][]byte, er } proof = append(proof, c_rlp) - branchChild := st.getNodeFromBranchRLP(c_rlp, k[i]) + branchChild := st.getNodeFromBranchRLP(c_rlp, int(k[i])) // branchChild is of length 1 when there is no child at this position in the branch // (`branchChild = [128]` in this case), but it is also of length 1 when `c_rlp` is a leaf. - if len(branchChild) == 1 { + if len(branchChild) == 1 && (branchChild[0] == 128 || branchChild[0] == 0) { // no child at this position - 128 is RLP encoding for nil object break } diff --git a/geth-utils/gethutil/mpt/witness/gen_witness_transactions_test.go b/geth-utils/gethutil/mpt/witness/gen_witness_transactions_test.go index 723234f255..962b9461b3 100644 --- a/geth-utils/gethutil/mpt/witness/gen_witness_transactions_test.go +++ b/geth-utils/gethutil/mpt/witness/gen_witness_transactions_test.go @@ -1,6 +1,7 @@ package witness import ( + "bytes" "fmt" "main/gethutil/mpt/trie" "main/gethutil/mpt/types" @@ -8,16 +9,40 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" ) +func makeTransactions(n int) []*types.Transaction { + txs := make([]*types.Transaction, n) + key, _ := crypto.GenerateKey() + signer := types.LatestSigner(params.TestChainConfig) + + for i := range txs { + amount := big.NewInt(int64(i)*10 ^ 9) + 
gas_price := big.NewInt(300_000) + data := make([]byte, 100) + // randomly assigned for debugging + data[3] = 3 + data[4] = 3 + data[5] = 3 + data[6] = 3 + data[7] = 3 + tx := types.NewTransaction(uint64(i), common.Address{}, amount, 10*10^6, gas_price, data) + signedTx, err := types.SignTx(tx, signer, key) + if err != nil { + panic(err) + } + txs[i] = signedTx + } + return txs +} + /* -TestNonHashedTransactionsStackTrieGetProof inserts 70 transactions into a stacktrie. -For each of the 70 modifications of the trie it asks for a proof - GetProof is called before +transactionsStackTrieInsertionTemplate inserts n transactions into a stacktrie. +For each of the n modifications of the trie it asks for a proof - GetProof is called before and after the modification. The transactions in the trie are not hashed and thus GetProof does not require to query a database to get the preimages. @@ -62,59 +87,84 @@ The first proof element is a branch with children at position 0 (branch B) and 1 The second element is the 16-th transaction. For example, the third byte (16) represents the transaction index. 
*/ -func TestNonHashedTransactionsStackTrieGetProof(t *testing.T) { - txs := make([]*types.Transaction, 70) - key, _ := crypto.GenerateKey() - signer := types.LatestSigner(params.TestChainConfig) - - for i := range txs { - amount := math.BigPow(2, int64(i)) - price := big.NewInt(300000) - data := make([]byte, 100) - tx := types.NewTransaction(uint64(i), common.Address{}, amount, 123457, price, data) - signedTx, err := types.SignTx(tx, signer, key) - if err != nil { - panic(err) - } - txs[i] = signedTx - } +func transactionsStackTrieInsertionTemplate(t *testing.T, n int) { + txs := makeTransactions(n) db := rawdb.NewMemoryDatabase() stackTrie := trie.NewStackTrie(db) - stackTrie.UpdateAndGetProofs(db, types.Transactions(txs)) + proofs, _ := stackTrie.UpdateAndGetProofs(db, types.Transactions(txs)) - fmt.Println("===") + rlp_last_tx, _ := txs[n-1].MarshalBinary() + last_proofC := proofs[len(proofs)-1].GetProofC() + + // The proof of the first tx (index 0) is appended after those of indices 1..0x7f, so it is the last proof whenever len(txs) <= 0x80. + // That's why we subtract 2 here. NOTE(review): the guard below uses 256 rather than 0x80 - confirm the intended bound. + if len(txs) > 1 && len(txs) < 256 { + last_proofC = proofs[len(proofs)-2].GetProofC() + } + last_leaf_proof := last_proofC[len(last_proofC)-1] + + if !bytes.Equal(last_leaf_proof, rlp_last_tx) { + fmt.Println("- last_tx ", rlp_last_tx) + fmt.Println("- last_proof ", last_leaf_proof) + t.Fail() + } +} + +func TestStackTrieInsertion_1Tx(t *testing.T) { + // Only one leaf + transactionsStackTrieInsertionTemplate(t, 1) +} + +func TestStackTrieInsertion_2Txs(t *testing.T) { + // One ext. node and one leaf + transactionsStackTrieInsertionTemplate(t, 2) +} + +func TestStackTrieInsertion_3Txs(t *testing.T) { + // One ext. node, one branch and one leaf + transactionsStackTrieInsertionTemplate(t, 3) +} + +func TestStackTrieInsertion_4Txs(t *testing.T) { + // One ext. node, one branch and two leaves + transactionsStackTrieInsertionTemplate(t, 4) +} + +func TestStackTrieInsertion_16Txs(t *testing.T) { + // One ext. 
node and one branch with full leaves (16 leaves) + transactionsStackTrieInsertionTemplate(t, 16) +} + +func TestStackTrieInsertion_17Txs(t *testing.T) { + // One ext. node, 3 branches and 17 leaves. + // The original ext. node turns into a branch (B1) which has children at position 0 and 1. + // At position 0 of B1, it has a branch with full leaves + // At position 1 of B1, it has a newly added leaf + transactionsStackTrieInsertionTemplate(t, 17) +} + +func TestStackTrieInsertion_33Txs(t *testing.T) { + // Follows the above test and has one more branch generated + transactionsStackTrieInsertionTemplate(t, 33) +} + +func TestStackTrieInsertion_ManyTxs(t *testing.T) { + // Just randomly picking a large number. + // The cap of block gas limit is 30M, the minimum gas cost of a tx is 21k + // 30M / 21k ~= 1429 + transactionsStackTrieInsertionTemplate(t, 2000) } /* -TestHashedTransactionsStackTrieGetProof inserts 2 transactions into a stacktrie, +batchedTransactionsStackTrieProofTemplate inserts n transactions into a stacktrie, the trie is then hashed (DeriveSha call). -The proof is asked for one of the two transactions. The transactions in the trie are hashed and thus +The proof is asked for one of the n transactions. The transactions in the trie are hashed and thus GetProof requires to query a database to get the preimages. 
*/ -func TestHashedTransactionsStackTrieGetProof(t *testing.T) { - txs := make([]*types.Transaction, 2) - key, _ := crypto.GenerateKey() - signer := types.LatestSigner(params.TestChainConfig) - - for i := range txs { - amount := math.BigPow(2, int64(i)) - price := big.NewInt(300000) - data := make([]byte, 100) - data[3] = 3 - data[4] = 3 - data[5] = 3 - data[6] = 3 - data[7] = 3 - tx := types.NewTransaction(uint64(i), common.Address{}, amount, 123457, price, data) - signedTx, err := types.SignTx(tx, signer, key) - if err != nil { - panic(err) - } - txs[i] = signedTx - } - +func batchedTransactionsStackTrieProofTemplate(n int) { + txs := makeTransactions(n) db := rawdb.NewMemoryDatabase() stackTrie := trie.NewStackTrie(db) @@ -130,7 +180,35 @@ func TestHashedTransactionsStackTrieGetProof(t *testing.T) { return } - fmt.Println(proofS) - + fmt.Println("proofS", proofS) fmt.Println("===") } + +func TestBatchedTxsProof_1Tx(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(1) +} + +func TestBatchedTxsProof_2Txs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(2) +} + +func TestBatchedTxsProof_3Txs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(3) +} +func TestBatchedTxsProof_4Txs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(4) +} + +func TestBatchedTxsProof_16Txs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(16) +} + +func TestBatchedTxsProof_17Txs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(17) +} +func TestBatchedTxsProof_33Txs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(33) +} +func TestBatchedTxsProof_ManyTxs(t *testing.T) { + batchedTransactionsStackTrieProofTemplate(2000) +}