Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add shares inclusion proofs #1233

Merged
merged 11 commits into from
Jan 17, 2023
1 change: 1 addition & 0 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ func New(
)

app.QueryRouter().AddRoute(prove.TxInclusionQueryPath, prove.QueryTxInclusionProof)
app.QueryRouter().AddRoute(prove.ShareInclusionQueryPath, prove.QueryShareInclusionProof)

app.mm.RegisterInvariants(&app.CrisisKeeper)
app.mm.RegisterRoutes(app.Router(), app.QueryRouter(), encodingConfig.Amino)
Expand Down
84 changes: 77 additions & 7 deletions app/test/block_size_test.go → app/test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ import (
"testing"
"time"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
"github.com/celestiaorg/celestia-app/testutil/blobfactory"
blobtypes "github.com/celestiaorg/celestia-app/x/blob/types"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
cosmosnet "github.com/cosmos/cosmos-sdk/testutil/network"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/suite"

"github.com/celestiaorg/celestia-app/app"
"github.com/celestiaorg/celestia-app/app/encoding"
"github.com/celestiaorg/celestia-app/pkg/appconsts"
"github.com/celestiaorg/celestia-app/testutil/blobfactory"
"github.com/celestiaorg/celestia-app/pkg/prove"
"github.com/celestiaorg/celestia-app/testutil/network"
"github.com/celestiaorg/celestia-app/x/blob"
blobtypes "github.com/celestiaorg/celestia-app/x/blob/types"

sdk "github.com/cosmos/cosmos-sdk/types"
abci "github.com/tendermint/tendermint/abci/types"
tmrand "github.com/tendermint/tendermint/libs/rand"
rpctypes "github.com/tendermint/tendermint/rpc/core/types"
Expand Down Expand Up @@ -284,3 +284,73 @@ func queryTx(clientCtx client.Context, hashHexStr string, prove bool) (*rpctypes

return node.Tx(context.Background(), hash, prove)
}

rootulp marked this conversation as resolved.
Show resolved Hide resolved
func (s *IntegrationTestSuite) TestSharesInclusionProof() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func (s *IntegrationTestSuite) TestSharesInclusionProof() {
func (s *IntegrationTestSuite) TestShareInclusionProof() {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

require := s.Require()
val := s.network.Validators[0]

// generate 100 randomly sized txs (max size == 100kb)
txs := blobfactory.RandBlobTxsWithAccounts(
s.cfg.TxConfig.TxEncoder(),
s.kr,
val.ClientCtx.GRPCClient,
100000,
true,
s.cfg.ChainID,
s.accounts[:20],
)

hashes := make([]string, len(txs))

for i, tx := range txs {
res, err := val.ClientCtx.BroadcastTxSync(tx)
require.NoError(err)
require.Equal(abci.CodeTypeOK, res.Code)
hashes[i] = res.TxHash
}

// wait a few blocks to clear the txs
for i := 0; i < 20; i++ {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[non-blocking] probably a worthwhile testing util helper

func WaitForNBlocks(num int) {
  for i := 0; i < n; i++ {
    require.NoError(s.network.WaitForNextBlock())
  }
}

Copy link
Member Author

@rach-id rach-id Jan 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created an issue #1238 since we're using this in many tests. So, it makes more sense to have a separate PR making the changes throughout the whole repo. Thanks a lot for pointing it out 👍

require.NoError(s.network.WaitForNextBlock())
}

for _, hash := range hashes {
txResp, err := queryTx(val.ClientCtx, hash, true)
require.NoError(err)
require.Equal(abci.CodeTypeOK, txResp.TxResult.Code)

// verify that the transaction inclusion proof is valid
require.True(txResp.Proof.VerifyProof())

// get the transaction shares
node, err := val.ClientCtx.GetNode()
require.NoError(err)
blockRes, err := node.Block(context.Background(), &txResp.Height)
require.NoError(err)
beginTxShare, endTxShare, err := prove.TxSharePosition(blockRes.Block.Txs, uint64(txResp.Index))
require.NoError(err)

txProof, err := node.ProveShares(
context.Background(),
uint64(txResp.Height),
beginTxShare,
endTxShare,
)
require.NoError(err)
require.NoError(txProof.Validate(blockRes.Block.DataHash))

// get the message shares
beginMsgShare, endMsgShare, err := prove.MsgSharesPosition(blockRes.Block.Txs[txResp.Index])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] rename message to blob so something like:

		// get the blob shares
		beginBlobShare, endBlobShare, err := prove.BlobSharePositions(blockRes.Block.Txs[txResp.Index])

also consider renaming MsgSharesPosition => BlobSharePositions because the function appears to return two positions. Alternative naming proposal: BlobShareRange

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

require.NoError(err)

// verify the message shares proof
msgProof, err := node.ProveShares(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] rename message to blob

		// verify the blob shares proof
		blobProof, err := node.ProveShares(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
msgProof, err := node.ProveShares(
blobProof, err := node.ProveShares(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

context.Background(),
uint64(txResp.Height),
beginMsgShare,
endMsgShare,
)
require.NoError(err)
require.NoError(msgProof.Validate(blockRes.Block.DataHash))
}
}
154 changes: 150 additions & 4 deletions pkg/prove/proof.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
package prove

import (
"bytes"
"errors"
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"

"github.com/celestiaorg/celestia-app/app/encoding"
"github.com/celestiaorg/celestia-app/pkg/appconsts"
"github.com/celestiaorg/celestia-app/pkg/da"
"github.com/celestiaorg/celestia-app/pkg/shares"
"github.com/celestiaorg/celestia-app/pkg/wrapper"
blobmodule "github.com/celestiaorg/celestia-app/x/blob"
blobtypes "github.com/celestiaorg/celestia-app/x/blob/types"
"github.com/celestiaorg/nmt/namespace"
"github.com/celestiaorg/rsmt2d"
"github.com/tendermint/tendermint/crypto/merkle"
tmbytes "github.com/tendermint/tendermint/libs/bytes"
tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
"github.com/tendermint/tendermint/types"
)

// TxInclusion uses the provided block data to progressively generate rows
// of a data square, and then using those shares to creates nmt inclusion proofs
// of a data square, and then using those shares to creates nmt inclusion proofs.
// It is possible that a transaction spans more than one row. In that case, we
// have to return more than one proof.
func TxInclusion(codec rsmt2d.Codec, data types.Data, txIndex uint64) (types.TxProof, error) {
// calculate the index of the shares that contain the tx
startPos, endPos, err := txSharePosition(data.Txs, txIndex)
startPos, endPos, err := TxSharePosition(data.Txs, txIndex)
if err != nil {
return types.TxProof{}, err
}
Expand Down Expand Up @@ -84,10 +94,10 @@ func TxInclusion(codec rsmt2d.Codec, data types.Data, txIndex uint64) (types.TxP
}, nil
}

// txSharePosition returns the start and end positions for the shares that
// TxSharePosition returns the start and end positions for the shares that
// include a given txIndex. Returns an error if index is greater than the length
// of txs.
func txSharePosition(txs types.Txs, txIndex uint64) (startSharePos, endSharePos uint64, err error) {
func TxSharePosition(txs types.Txs, txIndex uint64) (startSharePos, endSharePos uint64, err error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional][can be a FLUP issue] we may consider renaming to TxSharePositions because this function returns multiple positions. Alternative naming proposal: TxShareRange

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does FLUP mean? :D

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FLUP: Follow-Up

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sure #1241, thanks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this func was original done in core before our rather massive refactors to the shares pkg, so just noting mentally that we really need to refactor this one as well. Surprised this works tbh #703

if txIndex >= uint64(len(txs)) {
return startSharePos, endSharePos, errors.New("transaction index is greater than the number of txs")
}
Expand All @@ -107,6 +117,49 @@ func txSharePosition(txs types.Txs, txIndex uint64) (startSharePos, endSharePos
return startSharePos, endSharePos, nil
}

// MsgSharesPosition returns the start and end positions for the shares
// where a given message, referenced by its wrapped pfb transaction, was published at.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// where a given message, referenced by its wrapped pfb transaction, was published at.
// where a given blob, referenced by its wrapped PFB transaction, was published at.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Note: only supports transactions containing a single message
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Note: only supports transactions containing a single message
// Note: only supports transactions containing a single blob

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func MsgSharesPosition(tx types.Tx) (beginShare uint64, endShare uint64, err error) {
unwrappedTx, isMalleated := types.UnmarshalIndexWrapper(tx)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] consider avoiding the malleated terminology given there is no more malleation

Suggested change
unwrappedTx, isMalleated := types.UnmarshalIndexWrapper(tx)
indexWrappedTx, isIndexWrapped := types.UnmarshalIndexWrapper(tx)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if !isMalleated {
return beginShare, endShare, fmt.Errorf("not a malleated tx")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return beginShare, endShare, fmt.Errorf("not a malleated tx")
return beginShare, endShare, fmt.Errorf("not a index wrapped tx")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

encCfg := encoding.MakeConfig(blobmodule.AppModuleBasic{})
decoder := encCfg.TxConfig.TxDecoder()

decodedTx, err := decoder(unwrappedTx.Tx)
if err != nil {
return beginShare, endShare, err
}

if len(decodedTx.GetMsgs()) == 0 {
Copy link
Member Author

@rach-id rach-id Jan 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[note to reviewers]
This helper method only proves one message per blob. We add multiple messages later

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[blocking question]

This helper method only proves one message per blob. We add multiple messages later

should we check for or something similar then?

Suggested change
if len(decodedTx.GetMsgs()) == 0 {
if len(decodedTx.GetMsgs()) != 1 {

note, preview does not also change the error message below which we'll also have to change should this be accepted

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a subsequent check:

	if len(decodedTx.GetMsgs()) > 1 {
		return beginShare, endShare, fmt.Errorf("PayForBlob contains multiple messages and this is not currently supported")
	}

that should catch that

return beginShare, endShare, fmt.Errorf("pfb contains no messages")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit][proposal] to explicitly spell out PFB in error messages because they may be seen and reported by users who are not aware of Celestia specific acronyms.

Suggested change
return beginShare, endShare, fmt.Errorf("pfb contains no messages")
return beginShare, endShare, fmt.Errorf("PayForBlob contains no messages")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

if len(decodedTx.GetMsgs()) > 1 {
return beginShare, endShare, fmt.Errorf("pfb containing multiple messages. not currently supported")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return beginShare, endShare, fmt.Errorf("pfb containing multiple messages. not currently supported")
return beginShare, endShare, fmt.Errorf("PayForBlob contains multiple messages and this is not currently supported")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

if sdk.MsgTypeURL(decodedTx.GetMsgs()[0]) != blobtypes.URLMsgPayForBlob {
return beginShare, endShare, fmt.Errorf("transaction is not pfb")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[question] is it more slightly more accurate to say the message inside this transaction is not a MsgPayForBlob?

Suggested change
return beginShare, endShare, fmt.Errorf("transaction is not pfb")
return beginShare, endShare, fmt.Errorf("msg is not a MsgPayForBlob")

The reason I ask is because the above conditionals assume we're only dealing with a PFB tx

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

pfb, ok := decodedTx.GetMsgs()[0].(*blobtypes.MsgPayForBlob)
if !ok {
return beginShare, endShare, fmt.Errorf("unable to decode pfb")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return beginShare, endShare, fmt.Errorf("unable to decode pfb")
return beginShare, endShare, fmt.Errorf("unable to decode PayForBlob")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

if err = pfb.ValidateBasic(); err != nil {
return beginShare, endShare, err
}

beginShare = uint64(unwrappedTx.ShareIndexes[0])
sharesUsed := shares.SparseSharesNeeded(pfb.BlobSize)
return beginShare, beginShare + uint64(sharesUsed) - 1, nil
}

// txShareIndex returns the index of the compact share that would contain
// transactions with totalTxLen
func txShareIndex(totalTxLen int) (index uint64) {
Expand Down Expand Up @@ -193,3 +246,96 @@ func splitIntoRows(squareSize uint64, s []shares.Share) [][]shares.Share {
}
return rows
}

// SharesInclusion generates an nmt inclusion proof for a set of shares to the data root.
// expects the shares range to be pre-validated.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[question] can you please clarify what you mean by pre-validated?

Suggested change
// expects the shares range to be pre-validated.
// Expects the share range to be pre-validated.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Note: only supports inclusion proofs for shares belonging to the same namespace.
func SharesInclusion(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] this function name seems a bit ambiguous. Do these naming proposals convey it's behavior any clearer:

  • GenerateShareInclusionProof
  • NewShareInclusionProof

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

allRawShares []shares.Share,
squareSize uint64,
namespaceID namespace.ID,
startShare uint64,
endShare uint64,
) (types.SharesProof, error) {
startRow := startShare / squareSize
endRow := endShare / squareSize
startLeaf := startShare % squareSize
endLeaf := endShare % squareSize

eds, err := da.ExtendShares(squareSize, shares.ToBytes(allRawShares))
if err != nil {
return types.SharesProof{}, err
}

edsRowRoots := eds.RowRoots()

// create the binary merkle inclusion proof, for all the square rows, to the data root
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// create the binary merkle inclusion proof, for all the square rows, to the data root
// create the binary merkle inclusion proof for all the square rows to the data root

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_, allProofs := merkle.ProofsFromByteSlices(append(edsRowRoots, eds.ColRoots()...))
rowsProofs := make([]*merkle.Proof, endRow-startRow+1)
rowsRoots := make([]tmbytes.HexBytes, endRow-startRow+1)
for i := startRow; i <= endRow; i++ {
rowsProofs[i-startRow] = allProofs[i]
rowsRoots[i-startRow] = edsRowRoots[i]
}

// get the extended rows containing the shares.
rows := make([][]shares.Share, endRow-startRow+1)
for i := startRow; i <= endRow; i++ {
rows[i-startRow] = shares.FromBytes(eds.Row(uint(i)))
}

var sharesProofs []*tmproto.NMTProof //nolint:prealloc
var rawShares [][]byte
for i, row := range rows {
// create an nmt to generate a proof.
// we have to re-create the tree as the eds one is not accessible.
tree := wrapper.NewErasuredNamespacedMerkleTree(squareSize, uint(i))
for _, share := range row {
tree.Push(
share,
)
}

startLeafPos := startLeaf
endLeafPos := endLeaf

// if this is not the first row, then start with the first leaf
if i > 0 {
startLeafPos = 0
}
// if this is not the last row, then select for the rest of the row
if i != (len(rows) - 1) {
endLeafPos = squareSize - 1
}

rawShares = append(rawShares, shares.ToBytes(row[startLeafPos:endLeafPos+1])...)
proof, err := tree.Tree().ProveRange(int(startLeafPos), int(endLeafPos+1))
if err != nil {
return types.SharesProof{}, err
}

sharesProofs = append(sharesProofs, &tmproto.NMTProof{
Start: int32(proof.Start()),
End: int32(proof.End()),
Nodes: proof.Nodes(),
LeafHash: proof.LeafHash(),
})

// make sure that the generated root is the same as the eds row root.
if !bytes.Equal(rowsRoots[i].Bytes(), tree.Root()) {
return types.SharesProof{}, errors.New("eds row root is different than tree root")
}
}

return types.SharesProof{
RowsProof: types.RowsProof{
RowsRoots: rowsRoots,
Proofs: rowsProofs,
StartRow: uint32(startRow),
EndRow: uint32(endRow),
},
Data: rawShares,
SharesProofs: sharesProofs,
NamespaceID: namespaceID,
}, nil
}
Loading