From bca322bc361e62fce7cdaf7966c9ae690fc04de8 Mon Sep 17 00:00:00 2001 From: vuong Date: Thu, 24 Feb 2022 08:55:07 +0700 Subject: [PATCH 1/3] IbcTxMiddleware implements ibc tx handling middleware --- modules/core/middleware/middleware.go | 108 +++++ modules/core/middleware/middleware_test.go | 486 +++++++++++++++++++++ 2 files changed, 594 insertions(+) create mode 100644 modules/core/middleware/middleware.go create mode 100644 modules/core/middleware/middleware_test.go diff --git a/modules/core/middleware/middleware.go b/modules/core/middleware/middleware.go new file mode 100644 index 00000000000..f8a6006e20a --- /dev/null +++ b/modules/core/middleware/middleware.go @@ -0,0 +1,108 @@ +package middleware + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx" + clienttypes "github.com/cosmos/ibc-go/v3/modules/core/02-client/types" + channelkeeper "github.com/cosmos/ibc-go/v3/modules/core/04-channel/keeper" + channeltypes "github.com/cosmos/ibc-go/v3/modules/core/04-channel/types" +) + +var _ tx.Handler = ibcTxHandler{} + +type ibcTxHandler struct { + k channelkeeper.Keeper + next tx.Handler +} + +// IbcTxMiddleware implements ibc tx handling middleware +func IbcTxMiddleware(channelkeeper channelkeeper.Keeper) tx.Middleware { + return func(txh tx.Handler) tx.Handler { + return ibcTxHandler{ + k: channelkeeper, + next: txh, + } + } +} + +func (itxh ibcTxHandler) setIbcTxHandler(ctx context.Context, req tx.Request, simulate bool) error { + sdkCtx := sdk.UnwrapSDKContext(ctx) + // do not run redundancy check on DeliverTx or simulate + if (sdkCtx.IsCheckTx() || sdkCtx.IsReCheckTx()) && !simulate { + redundancies := 0 + packetMsgs := 0 + for _, m := range req.Tx.GetMsgs() { + switch msg := m.(type) { + case *channeltypes.MsgRecvPacket: + if _, found := itxh.k.GetPacketReceipt(sdkCtx, msg.Packet.GetDestPort(), msg.Packet.GetDestChannel(), msg.Packet.GetSequence()); found { + redundancies++ + } + packetMsgs++ + + case *channeltypes.MsgAcknowledgement: + if commitment := itxh.k.GetPacketCommitment(sdkCtx, msg.Packet.GetSourcePort(), msg.Packet.GetSourceChannel(), msg.Packet.GetSequence()); len(commitment) == 0 { + redundancies++ + } + packetMsgs++ + + case *channeltypes.MsgTimeout: + if commitment := itxh.k.GetPacketCommitment(sdkCtx, msg.Packet.GetSourcePort(), msg.Packet.GetSourceChannel(), msg.Packet.GetSequence()); len(commitment) == 0 { + redundancies++ + } + packetMsgs++ + + case *channeltypes.MsgTimeoutOnClose: + if commitment := itxh.k.GetPacketCommitment(sdkCtx, msg.Packet.GetSourcePort(), msg.Packet.GetSourceChannel(), msg.Packet.GetSequence()); len(commitment) == 0 { + redundancies++ + } + packetMsgs++ + + case *clienttypes.MsgUpdateClient: + // do nothing here, as we want to avoid updating clients if it is batched with only redundant messages + + default: + // if the multiMsg tx has a msg that is not a packet msg or update msg, then we will not return error + // regardless of if all packet messages are redundant. This ensures that non-packet messages get processed + // even if they get batched with redundant packet messages. + return nil + } + + } + + // only return error if all packet messages are redundant + if redundancies == packetMsgs && packetMsgs > 0 { + return channeltypes.ErrRedundantTx + } + } + return nil +} + +// CheckTx implements tx.Handler.CheckTx. +func (itxh ibcTxHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) { + err := itxh.setIbcTxHandler(ctx, req, false) + if err != nil { + return tx.Response{}, tx.ResponseCheckTx{}, err + } + + return itxh.next.CheckTx(ctx, req, checkReq) +} + +// DeliverTx implements tx.Handler.DeliverTx. +func (itxh ibcTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) { + err := itxh.setIbcTxHandler(ctx, req, false) + if err != nil { + return tx.Response{}, err + } + return itxh.next.DeliverTx(ctx, req) +} + +// SimulateTx implements tx.Handler.SimulateTx. +func (itxh ibcTxHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) { + err := itxh.setIbcTxHandler(ctx, req, true) + if err != nil { + return tx.Response{}, err + } + return itxh.next.SimulateTx(ctx, req) +} diff --git a/modules/core/middleware/middleware_test.go b/modules/core/middleware/middleware_test.go new file mode 100644 index 00000000000..72d847a8ab6 --- /dev/null +++ b/modules/core/middleware/middleware_test.go @@ -0,0 +1,486 @@ +package middleware_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/auth/middleware" + "github.com/stretchr/testify/suite" + + clienttypes "github.com/cosmos/ibc-go/v3/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v3/modules/core/04-channel/types" + ibctesting "github.com/cosmos/ibc-go/v3/testing" + "github.com/cosmos/ibc-go/v3/testing/mock" +) + +type MiddlewareTestSuite struct { + suite.Suite + + coordinator *ibctesting.Coordinator + + // testing chains used for convenience and readability + chainA *ibctesting.TestChain + chainB *ibctesting.TestChain + + path *ibctesting.Path +} + +// SetupTest creates a coordinator with 2 test chains. +func (suite *MiddlewareTestSuite) SetupTest() { + suite.coordinator = ibctesting.NewCoordinator(suite.T(), 2) + suite.chainA = suite.coordinator.GetChain(ibctesting.GetChainID(1)) + suite.chainB = suite.coordinator.GetChain(ibctesting.GetChainID(2)) + // commit some blocks so that QueryProof returns valid proof (cannot return valid query if height <= 1) + suite.coordinator.CommitNBlocks(suite.chainA, 2) + suite.coordinator.CommitNBlocks(suite.chainB, 2) + suite.path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(suite.path) +} + +// TestMiddlewareTestSuit runs all the tests within this package. +func TestMiddlewareTestSuit(t *testing.T) { + suite.Run(t, new(MiddlewareTestSuite)) +} + +func (suite *MiddlewareTestSuite) TestAnteDecorator() { + testCases := []struct { + name string + malleate func(suite *MiddlewareTestSuite) []sdk.Msg + expPass bool + }{ + { + "success on single msg", + func(suite *MiddlewareTestSuite) []sdk.Msg { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), 1, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + return []sdk.Msg{channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")} + }, + true, + }, + { + "success on multiple msgs", + func(suite *MiddlewareTestSuite) []sdk.Msg { + var msgs []sdk.Msg + + for i := 1; i <= 5; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + return msgs + }, + true, + }, + { + "success on multiple msgs: 1 fresh recv packet", + func(suite *MiddlewareTestSuite) []sdk.Msg { + var msgs []sdk.Msg + + for i := 1; i <= 5; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + err := suite.path.EndpointA.SendPacket(packet) + suite.Require().NoError(err) + + // receive all sequences except packet 3 + if i != 3 { + err = suite.path.EndpointB.RecvPacket(packet) + suite.Require().NoError(err) + } + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + + return msgs + }, + true, + }, + { + "success on multiple mixed msgs", + func(suite *MiddlewareTestSuite) []sdk.Msg { + var msgs []sdk.Msg + + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + err := suite.path.EndpointA.SendPacket(packet) + suite.Require().NoError(err) + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + msgs = append(msgs, channeltypes.NewMsgAcknowledgement(packet, []byte("ack"), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 4; i <= 6; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + msgs = append(msgs, channeltypes.NewMsgTimeout(packet, uint64(i), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + return msgs + }, + true, + }, + { + "success on multiple mixed msgs: 1 fresh packet of each type", + func(suite *MiddlewareTestSuite) []sdk.Msg { + var msgs []sdk.Msg + + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + err := suite.path.EndpointA.SendPacket(packet) + suite.Require().NoError(err) + + // receive all sequences except packet 3 + if i != 3 { + + err := suite.path.EndpointB.RecvPacket(packet) + suite.Require().NoError(err) + } + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + // receive all acks except ack 2 + if i != 2 { + err = suite.path.EndpointA.RecvPacket(packet) + suite.Require().NoError(err) + err = suite.path.EndpointB.AcknowledgePacket(packet, mock.MockAcknowledgement.Acknowledgement()) + suite.Require().NoError(err) + } + + msgs = append(msgs, channeltypes.NewMsgAcknowledgement(packet, []byte("ack"), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 4; i <= 6; i++ { + height := suite.chainA.LastHeader.GetHeight() + timeoutHeight := clienttypes.NewHeight(height.GetRevisionNumber(), height.GetRevisionHeight()+1) + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + timeoutHeight, 0) + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + // timeout packet + suite.coordinator.CommitNBlocks(suite.chainA, 3) + + // timeout packets except sequence 5 + if i != 5 { + suite.path.EndpointB.UpdateClient() + err = suite.path.EndpointB.TimeoutPacket(packet) + suite.Require().NoError(err) + } + + msgs = append(msgs, channeltypes.NewMsgTimeout(packet, uint64(i), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + return msgs + }, + true, + }, + { + "success on multiple mixed msgs: only 1 fresh msg in total", + func(suite *MiddlewareTestSuite) []sdk.Msg { + var msgs []sdk.Msg + + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all packets + suite.path.EndpointA.SendPacket(packet) + suite.path.EndpointB.RecvPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all acks + suite.path.EndpointB.SendPacket(packet) + suite.path.EndpointA.RecvPacket(packet) + suite.path.EndpointB.AcknowledgePacket(packet, mock.MockAcknowledgement.Acknowledgement()) + + msgs = append(msgs, channeltypes.NewMsgAcknowledgement(packet, []byte("ack"), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 4; i < 5; i++ { + height := suite.chainA.LastHeader.GetHeight() + timeoutHeight := clienttypes.NewHeight(height.GetRevisionNumber(), height.GetRevisionHeight()+1) + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + timeoutHeight, 0) + + // do not timeout packet, timeout msg is fresh + suite.path.EndpointB.SendPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgTimeout(packet, uint64(i), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + return msgs + }, + true, + }, + { + "success on single update client msg", + func(suite *MiddlewareTestSuite) []sdk.Msg { + return []sdk.Msg{&clienttypes.MsgUpdateClient{}} + }, + true, + }, + { + "success on multiple update clients", + func(suite *MiddlewareTestSuite) []sdk.Msg { + return []sdk.Msg{&clienttypes.MsgUpdateClient{}, &clienttypes.MsgUpdateClient{}, &clienttypes.MsgUpdateClient{}} + }, + true, + }, + { + "success on multiple update clients and fresh packet message", + func(suite *MiddlewareTestSuite) []sdk.Msg { + msgs := []sdk.Msg{&clienttypes.MsgUpdateClient{}, &clienttypes.MsgUpdateClient{}, &clienttypes.MsgUpdateClient{}} + + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), 1, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + return append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + }, + true, + }, + { + "success of tx with different msg type even if all packet messages are redundant", + func(suite *MiddlewareTestSuite) []sdk.Msg { + msgs := []sdk.Msg{&clienttypes.MsgUpdateClient{}} + + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all packets + suite.path.EndpointA.SendPacket(packet) + suite.path.EndpointB.RecvPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all acks + suite.path.EndpointB.SendPacket(packet) + suite.path.EndpointA.RecvPacket(packet) + suite.path.EndpointB.AcknowledgePacket(packet, mock.MockAcknowledgement.Acknowledgement()) + + msgs = append(msgs, channeltypes.NewMsgAcknowledgement(packet, []byte("ack"), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 4; i < 6; i++ { + height := suite.chainA.LastHeader.GetHeight() + timeoutHeight := clienttypes.NewHeight(height.GetRevisionNumber(), height.GetRevisionHeight()+1) + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + timeoutHeight, 0) + + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + // timeout packet + suite.coordinator.CommitNBlocks(suite.chainA, 3) + + suite.path.EndpointB.UpdateClient() + suite.path.EndpointB.TimeoutPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgTimeoutOnClose(packet, uint64(i), []byte("proof"), []byte("channelProof"), clienttypes.NewHeight(0, 1), "signer")) + } + + // append non packet and update message to msgs to ensure multimsg tx should pass + msgs = append(msgs, &clienttypes.MsgSubmitMisbehaviour{}) + + return msgs + }, + true, + }, + { + "no success on multiple mixed message: all are redundant", + func(suite *MiddlewareTestSuite) []sdk.Msg { + var msgs []sdk.Msg + + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all packets + suite.path.EndpointA.SendPacket(packet) + suite.path.EndpointB.RecvPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all acks + suite.path.EndpointB.SendPacket(packet) + suite.path.EndpointA.RecvPacket(packet) + suite.path.EndpointB.AcknowledgePacket(packet, mock.MockAcknowledgement.Acknowledgement()) + + msgs = append(msgs, channeltypes.NewMsgAcknowledgement(packet, []byte("ack"), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 4; i < 6; i++ { + height := suite.chainA.LastHeader.GetHeight() + timeoutHeight := clienttypes.NewHeight(height.GetRevisionNumber(), height.GetRevisionHeight()+1) + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + timeoutHeight, 0) + + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + // timeout packet + suite.coordinator.CommitNBlocks(suite.chainA, 3) + + suite.path.EndpointB.UpdateClient() + suite.path.EndpointB.TimeoutPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgTimeoutOnClose(packet, uint64(i), []byte("proof"), []byte("channelProof"), clienttypes.NewHeight(0, 1), "signer")) + } + return msgs + }, + false, + }, + { + "no success if msgs contain update clients and redundant packet messages", + func(suite *MiddlewareTestSuite) []sdk.Msg { + msgs := []sdk.Msg{&clienttypes.MsgUpdateClient{}, &clienttypes.MsgUpdateClient{}, &clienttypes.MsgUpdateClient{}} + + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all packets + suite.path.EndpointA.SendPacket(packet) + suite.path.EndpointB.RecvPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgRecvPacket(packet, []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 1; i <= 3; i++ { + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0) + + // receive all acks + suite.path.EndpointB.SendPacket(packet) + suite.path.EndpointA.RecvPacket(packet) + suite.path.EndpointB.AcknowledgePacket(packet, mock.MockAcknowledgement.Acknowledgement()) + + msgs = append(msgs, channeltypes.NewMsgAcknowledgement(packet, []byte("ack"), []byte("proof"), clienttypes.NewHeight(0, 1), "signer")) + } + for i := 4; i < 6; i++ { + height := suite.chainA.LastHeader.GetHeight() + timeoutHeight := clienttypes.NewHeight(height.GetRevisionNumber(), height.GetRevisionHeight()+1) + packet := channeltypes.NewPacket([]byte(mock.MockPacketData), uint64(i), + suite.path.EndpointB.ChannelConfig.PortID, suite.path.EndpointB.ChannelID, + suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID, + timeoutHeight, 0) + + err := suite.path.EndpointB.SendPacket(packet) + suite.Require().NoError(err) + + // timeout packet + suite.coordinator.CommitNBlocks(suite.chainA, 3) + + suite.path.EndpointB.UpdateClient() + suite.path.EndpointB.TimeoutPacket(packet) + + msgs = append(msgs, channeltypes.NewMsgTimeoutOnClose(packet, uint64(i), []byte("proof"), []byte("channelProof"), clienttypes.NewHeight(0, 1), "signer")) + } + return msgs + }, + false, + }, + } + + for _, tc := range testCases { + tc := tc + + suite.Run(tc.name, func() { + // reset suite + suite.SetupTest() + + k := suite.chainB.App.GetIBCKeeper().ChannelKeeper + ibcMiddleware := middleware.I + + msgs := tc.malleate(suite) + + deliverCtx := suite.chainB.GetContext().WithIsCheckTx(false) + checkCtx := suite.chainB.GetContext().WithIsCheckTx(true) + + // create multimsg tx + txBuilder := suite.chainB.TxConfig.NewTxBuilder() + err := txBuilder.SetMsgs(msgs...) + suite.Require().NoError(err) + tx := txBuilder.GetTx() + + next := func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { return ctx, nil } + + _, err = decorator.AnteHandle(deliverCtx, tx, false, next) + suite.Require().NoError(err, "antedecorator should not error on DeliverTx") + + _, err = decorator.AnteHandle(checkCtx, tx, false, next) + if tc.expPass { + suite.Require().NoError(err, "non-strict decorator did not pass as expected") + } else { + suite.Require().Error(err, "non-strict antehandler did not return error as expected") + } + }) + } +} From 93c0fea20313a2483b87b31d4f4a1fb5dc4c5294 Mon Sep 17 00:00:00 2001 From: vuong Date: Fri, 25 Feb 2022 22:22:21 +0700 Subject: [PATCH 2/3] go.mod add lazyledger/smt --- go.mod | 1 - go.sum | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 86f5969859b..8751d95def5 100644 --- a/go.mod +++ b/go.mod @@ -91,7 +91,6 @@ require ( github.com/jmhodges/levigo v1.0.0 // indirect github.com/keybase/go-keychain v0.0.0-20190712205309-48d3d31d256d // indirect github.com/klauspost/compress v1.13.6 // indirect - github.com/lazyledger/smt v0.2.1-0.20210709230900-03ea40719554 // indirect github.com/lib/pq v1.10.4 // indirect github.com/libp2p/go-buffer-pool v0.0.2 // indirect github.com/magiconair/properties v1.8.5 // indirect diff --git a/go.sum b/go.sum index 63f2234491b..018ce67a495 100644 --- a/go.sum +++ b/go.sum @@ -261,7 +261,6 @@ github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:z github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= github.com/coinbase/rosetta-sdk-go v0.7.2 h1:uCNrASIyt7rV9bA3gzPG3JDlxVP5v/zLgi01GWngncM= github.com/coinbase/rosetta-sdk-go v0.7.2/go.mod h1:wk9dvjZFSZiWSNkFuj3dMleTA1adLFotg5y71PhqKB4= -github.com/confio/ics23/go v0.6.6 h1:pkOy18YxxJ/r0XFDCnrl4Bjv6h4LkBSpLS6F38mrKL8= github.com/confio/ics23/go v0.6.6/go.mod h1:E45NqnlpxGnpfTWL/xauN7MRwEE28T4Dd4uraToOaKg= github.com/confio/ics23/go v0.7.0-rc h1:cH2I3xkPE6oD4tP5pmZDAfYq8V7VeXCr98X1MpARTaI= github.com/confio/ics23/go v0.7.0-rc/go.mod h1:E45NqnlpxGnpfTWL/xauN7MRwEE28T4Dd4uraToOaKg= @@ -288,8 +287,6 @@ github.com/cosmos/btcutil v1.0.4/go.mod h1:Ffqc8Hn6TJUdDgHBwIZLtrLQC1KdJ9jGJl/Tv github.com/cosmos/cosmos-proto v1.0.0-alpha7/go.mod h1:dosO4pSAbJF8zWCzCoTWP7nNsjcvSUBQmniFxDg5daw= github.com/cosmos/cosmos-proto v1.0.0-alpha7.0.20220208174455-213b76899fac h1:xdD9S2oFjpSb1sdIRM5LXgQ2SWIQTXKY7xosnp+w1LA= github.com/cosmos/cosmos-proto v1.0.0-alpha7.0.20220208174455-213b76899fac/go.mod h1:TUpFbyjtKZ+dXmJ48nAVio0lw6dMXMsWhEa1ln4iSw4= -github.com/cosmos/cosmos-sdk v0.46.0-alpha2.0.20220218134704-20e17ea71a9b h1:pezVz0v436NEHzo0RmaAlWEXUumAbsLpixvyOqW5jMY= -github.com/cosmos/cosmos-sdk v0.46.0-alpha2.0.20220218134704-20e17ea71a9b/go.mod h1:JySmURTWPCf0NeuCGeaLtSmkp3H0FUCBxRAptRmz+dw= github.com/cosmos/cosmos-sdk v0.46.0-alpha2.0.20220222235041-afbb0bd1941f h1:UmPFQGiK7fjt1Woe77NmILLnKTr8cDV6pGZlOHyygRo= github.com/cosmos/cosmos-sdk v0.46.0-alpha2.0.20220222235041-afbb0bd1941f/go.mod h1:w8r5e1R7DvZy48qBhAerzXxKC4wH3f5262y8F9RxkMA= github.com/cosmos/cosmos-sdk/api v0.1.0-alpha4 h1:z2si9sQNUTj2jw+24SujuUxcoNS3TVga/fdYsS4rJII= @@ -758,10 +755,13 @@ github.com/jedisct1/go-minisign v0.0.0-20190909160543-45766022959e/go.mod h1:G1C github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jgautheron/goconst v1.5.1/go.mod h1:aAosetZ5zaeC/2EfMeRswtxUFBpe2Hr7HzkgX4fanO4= +github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= +github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= +github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= github.com/jhump/protoreflect v1.6.1/go.mod h1:RZQ/lnuN+zqeRVpQigTwO6o0AJUkxbnSnpuG7toUTG4= -github.com/jhump/protoreflect v1.11.0 h1:bvACHUD1Ua/3VxY4aAMpItKMhhwbimlKFJKsLsVgDjU= github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E= github.com/jhump/protoreflect v1.12.0 h1:1NQ4FpWMgn3by/n1X0fbeKEUxP1wBt7+Oitpv01HR10= +github.com/jhump/protoreflect v1.12.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI= github.com/jingyugao/rowserrcheck v1.1.1/go.mod h1:4yvlZSDb3IyDTUZJUmpZfm2Hwok+Dtp+nu2qOq+er9c= github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af/go.mod h1:HEWGJkRDzjJY2sqdDwxccsGicWEf9BQOZsq2tV+xzM0= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= @@ -836,7 +836,6 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+ github.com/kyoh86/exportloopref v0.1.8/go.mod h1:1tUcJeiioIs7VWe5gcOObrux3lb66+sBqGZrRkMwPgg= github.com/labstack/echo/v4 v4.2.1/go.mod h1:AA49e0DZ8kk5jTOOCKNuPR6oTnBS0dYiM4FW1e6jwpg= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= -github.com/lazyledger/smt v0.2.1-0.20210709230900-03ea40719554 h1:nDOkLO7klmnEw1s4AyKt1Arvpgyh33uj1JmkYlJaDsk= github.com/lazyledger/smt v0.2.1-0.20210709230900-03ea40719554/go.mod h1:9+Pb2/tg1PvEgW7aFx4bFhDE4bvbI03zuJ8kb7nJ9Jc= github.com/ldez/gomoddirectives v0.2.2/go.mod h1:cpgBogWITnCfRq2qGoDkKMEVSaarhdBr6g8G04uz6d0= github.com/ldez/tagliatelle v0.3.0/go.mod h1:8s6WJQwEYHbKZDsp/LjArytKOG8qaMrKQQ3mFukHs88= From 28c86c88eefb3de9e56760d2aa20c1178a3fdb2d Mon Sep 17 00:00:00 2001 From: vuong Date: Fri, 25 Feb 2022 22:22:40 +0700 Subject: [PATCH 3/3] add middleware test && pass test --- modules/core/middleware/middleware_test.go | 59 ++++++++++++++++------ 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/modules/core/middleware/middleware_test.go b/modules/core/middleware/middleware_test.go index 72d847a8ab6..68ce3abeab1 100644 --- a/modules/core/middleware/middleware_test.go +++ b/modules/core/middleware/middleware_test.go @@ -1,16 +1,18 @@ -package middleware_test +package middleware import ( + "context" "testing" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/tx" + txtypes "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/x/auth/middleware" - "github.com/stretchr/testify/suite" - clienttypes "github.com/cosmos/ibc-go/v3/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v3/modules/core/04-channel/types" ibctesting "github.com/cosmos/ibc-go/v3/testing" "github.com/cosmos/ibc-go/v3/testing/mock" + "github.com/stretchr/testify/suite" ) type MiddlewareTestSuite struct { @@ -23,6 +25,12 @@ type MiddlewareTestSuite struct { chainB *ibctesting.TestChain path *ibctesting.Path + + txHandler txtypes.Handler +} + +//set TxHandler for test +func (suite *MiddlewareTestSuite) setupTxHandler() { } // SetupTest creates a coordinator with 2 test chains. @@ -455,32 +463,51 @@ func (suite *MiddlewareTestSuite) TestAnteDecorator() { suite.Run(tc.name, func() { // reset suite suite.SetupTest() - - k := suite.chainB.App.GetIBCKeeper().ChannelKeeper - ibcMiddleware := middleware.I - msgs := tc.malleate(suite) - deliverCtx := suite.chainB.GetContext().WithIsCheckTx(false) checkCtx := suite.chainB.GetContext().WithIsCheckTx(true) + txHandler := middleware.ComposeMiddlewares(noopTxHandler, IbcTxMiddleware(suite.chainB.App.GetIBCKeeper().ChannelKeeper)) // create multimsg tx txBuilder := suite.chainB.TxConfig.NewTxBuilder() err := txBuilder.SetMsgs(msgs...) suite.Require().NoError(err) - tx := txBuilder.GetTx() + testTx := txBuilder.GetTx() - next := func(ctx sdk.Context, tx sdk.Tx, simulate bool) (newCtx sdk.Context, err error) { return ctx, nil } + // test DeliverTx + _, err = txHandler.DeliverTx(sdk.WrapSDKContext(deliverCtx), txtypes.Request{Tx: testTx}) + suite.Require().NoError(err, "should not error on DeliverTx") - _, err = decorator.AnteHandle(deliverCtx, tx, false, next) - suite.Require().NoError(err, "antedecorator should not error on DeliverTx") - - _, err = decorator.AnteHandle(checkCtx, tx, false, next) + // test CheckTx + _, _, err = txHandler.CheckTx(sdk.WrapSDKContext(checkCtx), txtypes.Request{Tx: testTx}, txtypes.RequestCheckTx{}) if tc.expPass { - suite.Require().NoError(err, "non-strict decorator did not pass as expected") + suite.Require().NoError(err, "did not pass as expected") } else { - suite.Require().Error(err, "non-strict antehandler did not return error as expected") + suite.Require().Error(err, "did not return error as expected") } }) } } + +// customTxHandler is a test middleware that will run a custom function. +type customTxHandler struct { + fn func(context.Context, tx.Request) (tx.Response, error) +} + +var _ tx.Handler = customTxHandler{} + +func (h customTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) { + return h.fn(ctx, req) +} +func (h customTxHandler) CheckTx(ctx context.Context, req tx.Request, _ tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) { + res, err := h.fn(ctx, req) + return res, tx.ResponseCheckTx{}, err +} +func (h customTxHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) { + return h.fn(ctx, req) +} + +// noopTxHandler is a test middleware that returns an empty response. +var noopTxHandler = customTxHandler{func(_ context.Context, _ tx.Request) (tx.Response, error) { + return tx.Response{}, nil +}}