diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7cc53feb5a5..8f76a660095 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -40,3 +40,7 @@ # CODEOWNERS for docs /docs/ @colin-axner @AdityaSripal @crodriguezvega @charleenfei @damiannolan @chatton @DimitrisJim @srdtrk + +# CODEOWNERS for callbacks middleware + +/modules/apps/callbacks/ @colin-axner @AdityaSripal @damiannolan @srdtrk diff --git a/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go b/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go index 0bf4af8496a..7d61f3e0618 100644 --- a/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go +++ b/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go @@ -13,9 +13,9 @@ import ( controllerkeeper "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/controller/keeper" "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/controller/types" icatypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/types" - fee "github.com/cosmos/ibc-go/v7/modules/apps/29-fee" clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" host "github.com/cosmos/ibc-go/v7/modules/core/24-host" ibctesting "github.com/cosmos/ibc-go/v7/testing" ) @@ -836,7 +836,7 @@ func (suite *InterchainAccountsTestSuite) TestGetAppVersion() { cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) suite.Require().True(ok) - controllerStack := cbs.(fee.IBCMiddleware) + controllerStack := cbs.(porttypes.Middleware) appVersion, found := controllerStack.GetAppVersion(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) suite.Require().True(found) suite.Require().Equal(path.EndpointA.ChannelConfig.Version, appVersion) diff --git a/modules/apps/27-interchain-accounts/controller/keeper/keeper_test.go b/modules/apps/27-interchain-accounts/controller/keeper/keeper_test.go index 3fc9b5c2e84..b46ba39e2f2 100644 --- a/modules/apps/27-interchain-accounts/controller/keeper/keeper_test.go +++ b/modules/apps/27-interchain-accounts/controller/keeper/keeper_test.go @@ -11,7 +11,6 @@ import ( "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/controller/types" genesistypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/genesis/types" icatypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/types" - ibcfeekeeper "github.com/cosmos/ibc-go/v7/modules/apps/29-fee/keeper" channelkeeper "github.com/cosmos/ibc-go/v7/modules/core/04-channel/keeper" channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" ibctesting "github.com/cosmos/ibc-go/v7/testing" @@ -298,11 +297,9 @@ func (suite *KeeperTestSuite) TestGetAuthority() { func (suite *KeeperTestSuite) TestWithICS4Wrapper() { suite.SetupTest() - // test if the ics4 wrapper is the fee keeper initially + // test if the ics4 wrapper is the channel keeper initially ics4Wrapper := suite.chainA.GetSimApp().ICAControllerKeeper.GetICS4Wrapper() - _, isFeeKeeper := ics4Wrapper.(ibcfeekeeper.Keeper) - suite.Require().True(isFeeKeeper) _, isChannelKeeper := ics4Wrapper.(channelkeeper.Keeper) suite.Require().False(isChannelKeeper) @@ -312,6 +309,4 @@ func (suite *KeeperTestSuite) TestWithICS4Wrapper() { _, isChannelKeeper = ics4Wrapper.(channelkeeper.Keeper) suite.Require().True(isChannelKeeper) - _, isFeeKeeper = ics4Wrapper.(ibcfeekeeper.Keeper) - suite.Require().False(isFeeKeeper) } diff --git a/modules/apps/29-fee/ibc_middleware_test.go b/modules/apps/29-fee/ibc_middleware_test.go index b2a740a06e5..9ce8d8abe1f 100644 --- a/modules/apps/29-fee/ibc_middleware_test.go +++ b/modules/apps/29-fee/ibc_middleware_test.go @@ -1069,7 +1069,7 @@ func (suite *FeeTestSuite) TestGetAppVersion() { cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) suite.Require().True(ok) - feeModule := cbs.(ibcfee.IBCMiddleware) + feeModule := cbs.(porttypes.Middleware) appVersion, found := feeModule.GetAppVersion(suite.chainA.GetContext(), portID, channelID) diff --git a/modules/apps/callbacks/callbacks_test.go b/modules/apps/callbacks/callbacks_test.go new file mode 100644 index 00000000000..58e2dd971b8 --- /dev/null +++ b/modules/apps/callbacks/callbacks_test.go @@ -0,0 +1,263 @@ +package ibccallbacks_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + + icacontrollertypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/controller/types" + icatypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/types" + feetypes "github.com/cosmos/ibc-go/v7/modules/apps/29-fee/types" + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" + ibctesting "github.com/cosmos/ibc-go/v7/testing" + ibcmock "github.com/cosmos/ibc-go/v7/testing/mock" +) + +const maxCallbackGas = uint64(1000000) + +// CallbacksTestSuite defines the needed instances and methods to test callbacks +type CallbacksTestSuite struct { + suite.Suite + + coordinator *ibctesting.Coordinator + + chainA *ibctesting.TestChain + chainB *ibctesting.TestChain + + path *ibctesting.Path +} + +// setupChains sets up a coordinator with 2 test chains. +func (s *CallbacksTestSuite) setupChains() { + s.coordinator = ibctesting.NewCoordinator(s.T(), 2) + s.chainA = s.coordinator.GetChain(ibctesting.GetChainID(1)) + s.chainB = s.coordinator.GetChain(ibctesting.GetChainID(2)) + s.path = ibctesting.NewPath(s.chainA, s.chainB) +} + +// SetupTransferTest sets up a transfer channel between chainA and chainB +func (s *CallbacksTestSuite) SetupTransferTest() { + s.setupChains() + + s.path.EndpointA.ChannelConfig.PortID = ibctesting.TransferPort + s.path.EndpointB.ChannelConfig.PortID = ibctesting.TransferPort + s.path.EndpointA.ChannelConfig.Version = transfertypes.Version + s.path.EndpointB.ChannelConfig.Version = transfertypes.Version + + s.coordinator.Setup(s.path) +} + +// SetupFeeTransferTest sets up a fee middleware enabled transfer channel between chainA and chainB +func (s *CallbacksTestSuite) SetupFeeTransferTest() { + s.setupChains() + + feeTransferVersion := string(feetypes.ModuleCdc.MustMarshalJSON(&feetypes.Metadata{FeeVersion: feetypes.Version, AppVersion: transfertypes.Version})) + s.path.EndpointA.ChannelConfig.Version = feeTransferVersion + s.path.EndpointB.ChannelConfig.Version = feeTransferVersion + s.path.EndpointA.ChannelConfig.PortID = transfertypes.PortID + s.path.EndpointB.ChannelConfig.PortID = transfertypes.PortID + + s.coordinator.Setup(s.path) +} + +func (s *CallbacksTestSuite) SetupMockFeeTest() { + s.setupChains() + + mockFeeVersion := string(feetypes.ModuleCdc.MustMarshalJSON(&feetypes.Metadata{FeeVersion: feetypes.Version, AppVersion: ibcmock.Version})) + s.path.EndpointA.ChannelConfig.Version = mockFeeVersion + s.path.EndpointB.ChannelConfig.Version = mockFeeVersion + s.path.EndpointA.ChannelConfig.PortID = ibctesting.MockFeePort + s.path.EndpointB.ChannelConfig.PortID = ibctesting.MockFeePort +} + +// SetupICATest sets up an interchain accounts channel between chainA (controller) and chainB (host). +// It funds and returns the interchain account address owned by chainA's SenderAccount. +func (s *CallbacksTestSuite) SetupICATest() string { + s.setupChains() + s.coordinator.SetupConnections(s.path) + + icaOwner := s.chainA.SenderAccount.GetAddress().String() + // ICAVersion defines a interchain accounts version string + icaVersion := icatypes.NewDefaultMetadataString(s.path.EndpointA.ConnectionID, s.path.EndpointB.ConnectionID) + icaControllerPortID, err := icatypes.NewControllerPortID(icaOwner) + s.Require().NoError(err) + + s.path.SetChannelOrdered() + s.path.EndpointA.ChannelConfig.PortID = icaControllerPortID + s.path.EndpointB.ChannelConfig.PortID = icatypes.HostPortID + s.path.EndpointA.ChannelConfig.Version = icaVersion + s.path.EndpointB.ChannelConfig.Version = icaVersion + + s.RegisterInterchainAccount(icaOwner) + // open chan init must be skipped. So we cannot use .CreateChannels() + err = s.path.EndpointB.ChanOpenTry() + s.Require().NoError(err) + err = s.path.EndpointA.ChanOpenAck() + s.Require().NoError(err) + err = s.path.EndpointB.ChanOpenConfirm() + s.Require().NoError(err) + + interchainAccountAddr, found := s.chainB.GetSimApp().ICAHostKeeper.GetInterchainAccountAddress(s.chainB.GetContext(), s.path.EndpointA.ConnectionID, s.path.EndpointA.ChannelConfig.PortID) + s.Require().True(found) + + // fund the interchain account on chainB + msgBankSend := &banktypes.MsgSend{ + FromAddress: s.chainB.SenderAccount.GetAddress().String(), + ToAddress: interchainAccountAddr, + Amount: sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(100000))), + } + res, err := s.chainB.SendMsgs(msgBankSend) + s.Require().NotEmpty(res) + s.Require().NoError(err) + + return interchainAccountAddr +} + +// RegisterInterchainAccount submits a MsgRegisterInterchainAccount and updates the controller endpoint with the +// channel created. +func (s *CallbacksTestSuite) RegisterInterchainAccount(owner string) { + msgRegister := icacontrollertypes.NewMsgRegisterInterchainAccount(s.path.EndpointA.ConnectionID, owner, s.path.EndpointA.ChannelConfig.Version) + + res, err := s.chainA.SendMsgs(msgRegister) + s.Require().NotEmpty(res) + s.Require().NoError(err) + + channelID, err := ibctesting.ParseChannelIDFromEvents(res.Events) + s.Require().NoError(err) + + s.path.EndpointA.ChannelID = channelID +} + +// AssertHasExecutedExpectedCallback checks the stateful entries and counters based on callbacktype. +// It assumes that the source chain is chainA and the destination chain is chainB. +func (s *CallbacksTestSuite) AssertHasExecutedExpectedCallback(callbackType types.CallbackType, expSuccess bool) { + var expStatefulEntries uint8 + if expSuccess { + // if the callback is expected to be successful, + // we expect at least one state entry + expStatefulEntries = 1 + } + + sourceStatefulCounter := s.chainA.GetSimApp().MockContractKeeper.GetStateEntryCounter(s.chainA.GetContext()) + destStatefulCounter := s.chainB.GetSimApp().MockContractKeeper.GetStateEntryCounter(s.chainB.GetContext()) + + switch callbackType { + case "none": + s.Require().Equal(uint8(0), sourceStatefulCounter) + s.Require().Equal(uint8(0), destStatefulCounter) + + case types.CallbackTypeSendPacket: + s.Require().Equal(expStatefulEntries, sourceStatefulCounter, "unexpected stateful entry amount for source send packet callback") + s.Require().Equal(uint8(0), destStatefulCounter) + + case types.CallbackTypeAcknowledgementPacket, types.CallbackTypeTimeoutPacket: + expStatefulEntries *= 2 // expect OnAcknowledgement/OnTimeout to be successful as well as the initial SendPacket + s.Require().Equal(expStatefulEntries, sourceStatefulCounter, "unexpected stateful entry amount for source acknowledgement/timeout callbacks") + s.Require().Equal(uint8(0), destStatefulCounter) + + case types.CallbackTypeReceivePacket: + s.Require().Equal(uint8(0), sourceStatefulCounter) + s.Require().Equal(expStatefulEntries, destStatefulCounter) + + default: + s.FailNow(fmt.Sprintf("invalid callback type %s", callbackType)) + } + + s.AssertCallbackCounters(callbackType) +} + +func (s *CallbacksTestSuite) AssertCallbackCounters(callbackType types.CallbackType) { + sourceCounters := s.chainA.GetSimApp().MockContractKeeper.Counters + destCounters := s.chainB.GetSimApp().MockContractKeeper.Counters + + switch callbackType { + case "none": + s.Require().Len(sourceCounters, 0) + s.Require().Len(destCounters, 0) + + case types.CallbackTypeSendPacket: + s.Require().Len(sourceCounters, 1) + s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket]) + + case types.CallbackTypeAcknowledgementPacket: + s.Require().Len(sourceCounters, 2) + s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket]) + s.Require().Equal(1, sourceCounters[types.CallbackTypeAcknowledgementPacket]) + + s.Require().Len(destCounters, 0) + + case types.CallbackTypeReceivePacket: + s.Require().Len(sourceCounters, 0) + s.Require().Len(destCounters, 1) + s.Require().Equal(1, destCounters[types.CallbackTypeReceivePacket]) + + case types.CallbackTypeTimeoutPacket: + s.Require().Len(sourceCounters, 2) + s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket]) + s.Require().Equal(1, sourceCounters[types.CallbackTypeTimeoutPacket]) + + s.Require().Len(destCounters, 0) + + default: + s.FailNow(fmt.Sprintf("invalid callback type %s", callbackType)) + } +} + +func TestIBCCallbacksTestSuite(t *testing.T) { + suite.Run(t, new(CallbacksTestSuite)) +} + +// AssertHasExecutedExpectedCallbackWithFee checks if only the expected type of callback has been executed +// and that the expected ics-29 fee has been paid. +func (s *CallbacksTestSuite) AssertHasExecutedExpectedCallbackWithFee( + callbackType types.CallbackType, isSuccessful bool, isTimeout bool, + originalSenderBalance sdk.Coins, fee feetypes.Fee, +) { + // Recall that: + // - the source chain is chainA + // - forward relayer is chainB.SenderAccount + // - reverse relayer is chainA.SenderAccount + // - The counterparty payee of the forward relayer in chainA is chainB.SenderAccount (as a chainA account) + + // We only check if the fee is paid if the callback is successful. + if !isTimeout && isSuccessful { + // check forward relay balance + s.Require().Equal( + fee.RecvFee, + sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainB.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom)), + ) + + s.Require().Equal( + fee.AckFee.Add(fee.TimeoutFee...), // ack fee paid, timeout fee refunded + sdk.NewCoins( + s.chainA.GetSimApp().BankKeeper.GetBalance( + s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), + ibctesting.TestCoin.Denom), + ).Sub(originalSenderBalance[0]), + ) + } else if isSuccessful { + // forward relay balance should be 0 + s.Require().Equal( + sdk.NewCoin(ibctesting.TestCoin.Denom, sdkmath.ZeroInt()), + s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainB.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom), + ) + + // all fees should be returned as sender is the reverse relayer + s.Require().Equal( + fee.Total(), + sdk.NewCoins( + s.chainA.GetSimApp().BankKeeper.GetBalance( + s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), + ibctesting.TestCoin.Denom), + ).Sub(originalSenderBalance[0]), + ) + } + s.AssertHasExecutedExpectedCallback(callbackType, isSuccessful) +} diff --git a/modules/apps/callbacks/export_test.go b/modules/apps/callbacks/export_test.go new file mode 100644 index 00000000000..b7ea323910f --- /dev/null +++ b/modules/apps/callbacks/export_test.go @@ -0,0 +1,26 @@ +package ibccallbacks + +/* + This file is to allow for unexported functions and fields to be accessible to the testing package. +*/ + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" +) + +// ProcessCallback is a wrapper around processCallback to allow the function to be directly called in tests. +func (im IBCMiddleware) ProcessCallback( + ctx sdk.Context, packet channeltypes.Packet, callbackType types.CallbackType, + callbackData types.CallbackData, callbackExecutor func(sdk.Context) error, +) error { + return im.processCallback(ctx, packet, callbackType, callbackData, callbackExecutor) +} + +// GetICS4Wrapper is a getter for the IBCMiddleware's ICS4Wrapper. +func (im *IBCMiddleware) GetICS4Wrapper() porttypes.ICS4Wrapper { + return im.ics4Wrapper +} diff --git a/modules/apps/callbacks/fee_transfer_test.go b/modules/apps/callbacks/fee_transfer_test.go new file mode 100644 index 00000000000..4e448d4a77d --- /dev/null +++ b/modules/apps/callbacks/fee_transfer_test.go @@ -0,0 +1,197 @@ +package ibccallbacks_test + +import ( + "fmt" + + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + + feetypes "github.com/cosmos/ibc-go/v7/modules/apps/29-fee/types" + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + ibctesting "github.com/cosmos/ibc-go/v7/testing" +) + +var ( + defaultRecvFee = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(100)}} + defaultAckFee = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(200)}} + defaultTimeoutFee = sdk.Coins{sdk.Coin{Denom: sdk.DefaultBondDenom, Amount: sdkmath.NewInt(300)}} +) + +func (s *CallbacksTestSuite) TestIncentivizedTransferCallbacks() { + testCases := []struct { + name string + transferMemo string + expCallback types.CallbackType + expSuccess bool + }{ + { + "success: transfer with no memo", + "", + "none", + true, + }, + { + "success: dest callback", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeReceivePacket, + true, + }, + { + "success: dest callback with other json fields", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}, "something_else": {}}`, callbackAddr), + types.CallbackTypeReceivePacket, + true, + }, + { + "success: dest callback with malformed json", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}, malformed}`, callbackAddr), + "none", + true, + }, + { + "success: dest callback with missing address", + `{"dest_callback": {"address": ""}}`, + "none", + true, + }, + { + "success: source callback", + fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeAcknowledgementPacket, + true, + }, + { + "success: source callback with other json fields", + fmt.Sprintf(`{"src_callback": {"address": "%s"}, "something_else": {}}`, callbackAddr), + types.CallbackTypeAcknowledgementPacket, + true, + }, + { + "success: source callback with malformed json", + fmt.Sprintf(`{"src_callback": {"address": "%s"}, malformed}`, callbackAddr), + "none", + true, + }, + { + "success: source callback with missing address", + `{"src_callback": {"address": ""}}`, + "none", + true, + }, + { + "failure: dest callback with low gas (panic)", + fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + types.CallbackTypeReceivePacket, + false, + }, + { + "failure: source callback with low gas (panic)", + fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + types.CallbackTypeSendPacket, + false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + s.SetupFeeTransferTest() + + fee := feetypes.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + + s.ExecutePayPacketFeeMsg(fee) + preRelaySenderBalance := sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom)) + s.ExecuteTransfer(tc.transferMemo) + // we manually subtract the transfer amount from the preRelaySenderBalance because ExecuteTransfer + // also relays the packet, which will trigger the fee payments. + preRelaySenderBalance = preRelaySenderBalance.Sub(ibctesting.TestCoin) + + // after incentivizing the packets + s.AssertHasExecutedExpectedCallbackWithFee(tc.expCallback, tc.expSuccess, false, preRelaySenderBalance, fee) + }) + } +} + +func (s *CallbacksTestSuite) TestIncentivizedTransferTimeoutCallbacks() { + testCases := []struct { + name string + transferMemo string + expCallback types.CallbackType + expSuccess bool + }{ + { + "success: transfer with no memo", + "", + "none", + true, + }, + { + "success: dest callback", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, callbackAddr), + "none", + true, // timeouts don't reach destination chain execution + }, + { + "success: source callback", + fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeTimeoutPacket, + true, + }, + { + "success: dest callback with low gas (panic)", + fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + "none", // timeouts don't reach destination chain execution + false, + }, + { + "failure: source callback with low gas (panic)", + fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + types.CallbackTypeSendPacket, + false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + s.SetupFeeTransferTest() + + fee := feetypes.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee) + + s.ExecutePayPacketFeeMsg(fee) + preRelaySenderBalance := sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom)) + s.ExecuteTransferTimeout(tc.transferMemo, 1) + + // after incentivizing the packets + s.AssertHasExecutedExpectedCallbackWithFee(tc.expCallback, tc.expSuccess, true, preRelaySenderBalance, fee) + }) + } +} + +func (s *CallbacksTestSuite) ExecutePayPacketFeeMsg(fee feetypes.Fee) { + msg := feetypes.NewMsgPayPacketFee( + fee, s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID, + s.chainA.SenderAccount.GetAddress().String(), nil, + ) + + // fetch the account balance before fees are escrowed and assert the difference below + preEscrowBalance := s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + + res, err := s.chainA.SendMsgs(msg) + s.Require().NoError(err) + s.Require().NotNil(res) + + postEscrowBalance := s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + s.Require().Equal(postEscrowBalance.AddAmount(fee.Total().AmountOf(sdk.DefaultBondDenom)), preEscrowBalance) + + // register counterparty address on chainB + // relayerAddress is address of sender account on chainB, but we will use it on chainA + // to differentiate from the chainA.SenderAccount for checking successful relay payouts + relayerAddress := s.chainB.SenderAccount.GetAddress() + + msgRegister := feetypes.NewMsgRegisterCounterpartyPayee( + s.path.EndpointB.ChannelConfig.PortID, s.path.EndpointB.ChannelID, + s.chainB.SenderAccount.GetAddress().String(), relayerAddress.String(), + ) + _, err = s.chainB.SendMsgs(msgRegister) + s.Require().NoError(err) // message committed +} diff --git a/modules/apps/callbacks/ibc_middleware.go b/modules/apps/callbacks/ibc_middleware.go new file mode 100644 index 00000000000..4328e32b738 --- /dev/null +++ b/modules/apps/callbacks/ibc_middleware.go @@ -0,0 +1,347 @@ +package ibccallbacks + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" +) + +var ( + _ porttypes.Middleware = (*IBCMiddleware)(nil) + _ porttypes.PacketDataUnmarshaler = (*IBCMiddleware)(nil) +) + +// IBCMiddleware implements the ICS26 callbacks for the ibc-callbacks middleware given +// the underlying application. +type IBCMiddleware struct { + app types.CallbacksCompatibleModule + ics4Wrapper porttypes.ICS4Wrapper + + contractKeeper types.ContractKeeper + + // maxCallbackGas defines the maximum amount of gas that a callback actor can ask the + // relayer to pay for. If a callback fails due to insufficient gas, the entire tx + // is reverted if the relayer hadn't provided the minimum(userDefinedGas, maxCallbackGas). + // If the actor hasn't defined a gas limit, then it is assumed to be the maxCallbackGas. + maxCallbackGas uint64 +} + +// NewIBCMiddleware creates a new IBCMiddlware given the keeper and underlying application. +// The underlying application must implement the required callback interfaces. +func NewIBCMiddleware( + app porttypes.IBCModule, ics4Wrapper porttypes.ICS4Wrapper, + contractKeeper types.ContractKeeper, maxCallbackGas uint64, +) IBCMiddleware { + packetDataUnmarshalerApp, ok := app.(types.CallbacksCompatibleModule) + if !ok { + panic(fmt.Errorf("underlying application does not implement %T", (*types.CallbacksCompatibleModule)(nil))) + } + + if ics4Wrapper == nil { + panic(fmt.Errorf("ICS4Wrapper cannot be nil")) + } + + if contractKeeper == nil { + panic(fmt.Errorf("contract keeper cannot be nil")) + } + + if maxCallbackGas == 0 { + panic(fmt.Errorf("maxCallbackGas cannot be zero")) + } + + return IBCMiddleware{ + app: packetDataUnmarshalerApp, + ics4Wrapper: ics4Wrapper, + contractKeeper: contractKeeper, + maxCallbackGas: maxCallbackGas, + } +} + +// WithICS4Wrapper sets the ICS4Wrapper. This function may be used after the +// middleware's creation to set the middleware which is above this module in +// the IBC application stack. +func (im *IBCMiddleware) WithICS4Wrapper(wrapper porttypes.ICS4Wrapper) { + im.ics4Wrapper = wrapper +} + +// SendPacket implements source callbacks for sending packets. +// It defers to the underlying application and then calls the contract callback. +// If the contract callback returns an error, panics, or runs out of gas, then +// the packet send is rejected. +func (im IBCMiddleware) SendPacket( + ctx sdk.Context, + chanCap *capabilitytypes.Capability, + sourcePort string, + sourceChannel string, + timeoutHeight clienttypes.Height, + timeoutTimestamp uint64, + data []byte, +) (uint64, error) { + seq, err := im.ics4Wrapper.SendPacket(ctx, chanCap, sourcePort, sourceChannel, timeoutHeight, timeoutTimestamp, data) + if err != nil { + return 0, err + } + + // Reconstruct the sent packet. The destination portID and channelID are intentionally left empty as the sender information + // is only derived from the source packet information in `GetSourceCallbackData`. + reconstructedPacket := channeltypes.NewPacket(data, seq, sourcePort, sourceChannel, "", "", timeoutHeight, timeoutTimestamp) + + callbackData, err := types.GetSourceCallbackData(im.app, reconstructedPacket, ctx.GasMeter().GasRemaining(), im.maxCallbackGas) + // SendPacket is not blocked if the packet does not opt-in to callbacks + if err != nil { + return seq, nil + } + + callbackExecutor := func(cachedCtx sdk.Context) error { + return im.contractKeeper.IBCSendPacketCallback( + cachedCtx, sourcePort, sourceChannel, timeoutHeight, timeoutTimestamp, data, callbackData.CallbackAddress, callbackData.SenderAddress, + ) + } + + err = im.processCallback(ctx, reconstructedPacket, types.CallbackTypeSendPacket, callbackData, callbackExecutor) + // contract keeper is allowed to reject the packet send. + if err != nil { + return 0, err + } + + types.EmitCallbackEvent(ctx, reconstructedPacket, types.CallbackTypeSendPacket, callbackData, nil) + return seq, nil +} + +// OnAcknowledgementPacket implements source callbacks for acknowledgement packets. +// It defers to the underlying application and then calls the contract callback. +// If the contract callback runs out of gas and may be retried with a higher gas limit then the state changes are +// reverted via a panic. +func (im IBCMiddleware) OnAcknowledgementPacket( + ctx sdk.Context, + packet channeltypes.Packet, + acknowledgement []byte, + relayer sdk.AccAddress, +) error { + // we first call the underlying app to handle the acknowledgement + err := im.app.OnAcknowledgementPacket(ctx, packet, acknowledgement, relayer) + if err != nil { + return err + } + + callbackData, err := types.GetSourceCallbackData(im.app, packet, ctx.GasMeter().GasRemaining(), im.maxCallbackGas) + // OnAcknowledgementPacket is not blocked if the packet does not opt-in to callbacks + if err != nil { + return nil + } + + callbackExecutor := func(cachedCtx sdk.Context) error { + return im.contractKeeper.IBCOnAcknowledgementPacketCallback( + cachedCtx, packet, acknowledgement, relayer, callbackData.CallbackAddress, callbackData.SenderAddress, + ) + } + + // callback execution errors are not allowed to block the packet lifecycle, they are only used in event emissions + err = im.processCallback(ctx, packet, types.CallbackTypeAcknowledgementPacket, callbackData, callbackExecutor) + types.EmitCallbackEvent(ctx, packet, types.CallbackTypeAcknowledgementPacket, callbackData, err) + + return nil +} + +// OnTimeoutPacket implements timeout source callbacks for the ibc-callbacks middleware. +// It defers to the underlying application and then calls the contract callback. +// If the contract callback runs out of gas and may be retried with a higher gas limit then the state changes are +// reverted via a panic. +func (im IBCMiddleware) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet, relayer sdk.AccAddress) error { + err := im.app.OnTimeoutPacket(ctx, packet, relayer) + if err != nil { + return err + } + + callbackData, err := types.GetSourceCallbackData(im.app, packet, ctx.GasMeter().GasRemaining(), im.maxCallbackGas) + // OnTimeoutPacket is not blocked if the packet does not opt-in to callbacks + if err != nil { + return nil + } + + callbackExecutor := func(cachedCtx sdk.Context) error { + return im.contractKeeper.IBCOnTimeoutPacketCallback(cachedCtx, packet, relayer, callbackData.CallbackAddress, callbackData.SenderAddress) + } + + // callback execution errors are not allowed to block the packet lifecycle, they are only used in event emissions + err = im.processCallback(ctx, packet, types.CallbackTypeTimeoutPacket, callbackData, callbackExecutor) + types.EmitCallbackEvent(ctx, packet, types.CallbackTypeTimeoutPacket, callbackData, err) + + return nil +} + +// OnRecvPacket implements the ReceivePacket destination callbacks for the ibc-callbacks middleware during +// synchronous packet acknowledgement. +// It defers to the underlying application and then calls the contract callback. +// If the contract callback runs out of gas and may be retried with a higher gas limit then the state changes are +// reverted via a panic. +func (im IBCMiddleware) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, relayer sdk.AccAddress) ibcexported.Acknowledgement { + ack := im.app.OnRecvPacket(ctx, packet, relayer) + // if ack is nil (asynchronous acknowledgements), then the callback will be handled in WriteAcknowledgement + // if ack is not successful, all state changes are reverted. If a packet cannot be received, then there is + // no need to execute a callback on the receiving chain. + if ack == nil || !ack.Success() { + return ack + } + + callbackData, err := types.GetDestCallbackData(im.app, packet, ctx.GasMeter().GasRemaining(), im.maxCallbackGas) + // OnRecvPacket is not blocked if the packet does not opt-in to callbacks + if err != nil { + return ack + } + + callbackExecutor := func(cachedCtx sdk.Context) error { + return im.contractKeeper.IBCReceivePacketCallback(cachedCtx, packet, ack, callbackData.CallbackAddress) + } + + // callback execution errors are not allowed to block the packet lifecycle, they are only used in event emissions + err = im.processCallback(ctx, packet, types.CallbackTypeReceivePacket, callbackData, callbackExecutor) + types.EmitCallbackEvent(ctx, packet, types.CallbackTypeReceivePacket, callbackData, err) + + return ack +} + +// WriteAcknowledgement implements the ReceivePacket destination callbacks for the ibc-callbacks middleware +// during asynchronous packet acknowledgement. +// It defers to the underlying application and then calls the contract callback. +// If the contract callback runs out of gas and may be retried with a higher gas limit then the state changes are +// reverted via a panic. +func (im IBCMiddleware) WriteAcknowledgement( + ctx sdk.Context, + chanCap *capabilitytypes.Capability, + packet ibcexported.PacketI, + ack ibcexported.Acknowledgement, +) error { + err := im.ics4Wrapper.WriteAcknowledgement(ctx, chanCap, packet, ack) + if err != nil { + return err + } + + callbackData, err := types.GetDestCallbackData(im.app, packet, ctx.GasMeter().GasRemaining(), im.maxCallbackGas) + // WriteAcknowledgement is not blocked if the packet does not opt-in to callbacks + if err != nil { + return nil + } + + callbackExecutor := func(cachedCtx sdk.Context) error { + return im.contractKeeper.IBCReceivePacketCallback(cachedCtx, packet, ack, callbackData.CallbackAddress) + } + + // callback execution errors are not allowed to block the packet lifecycle, they are only used in event emissions + err = im.processCallback(ctx, packet, types.CallbackTypeReceivePacket, callbackData, callbackExecutor) + types.EmitCallbackEvent(ctx, packet, types.CallbackTypeReceivePacket, callbackData, err) + + return nil +} + +// processCallback executes the callbackExecutor and reverts contract changes if the callbackExecutor fails. +// +// panics if +// - the contractExecutor panics for any reason, and the callbackType is SendPacket, or +// - the contractExecutor runs out of gas and the relayer has not reserved gas grater than or equal to +// CommitGasLimit. +func (IBCMiddleware) processCallback( + ctx sdk.Context, packet ibcexported.PacketI, callbackType types.CallbackType, + callbackData types.CallbackData, callbackExecutor func(sdk.Context) error, +) (err error) { + cachedCtx, writeFn := ctx.CacheContext() + cachedCtx = cachedCtx.WithGasMeter(sdk.NewGasMeter(callbackData.ExecutionGasLimit)) + + defer func() { + // consume the minimum of g.consumed and g.limit + ctx.GasMeter().ConsumeGas(cachedCtx.GasMeter().GasConsumedToLimit(), fmt.Sprintf("ibc %s callback", callbackType)) + + // recover from all panics except during SendPacket callbacks + if r := recover(); r != nil { + if callbackType == types.CallbackTypeSendPacket { + panic(r) + } + } + + // if the callback ran out of gas and the relayer has not reserved enough gas, then revert the state + if cachedCtx.GasMeter().IsPastLimit() && callbackData.AllowRetry() { + panic(sdk.ErrorOutOfGas{Descriptor: fmt.Sprintf("ibc %s callback out of gas; commitGasLimit: %d", callbackType, callbackData.CommitGasLimit)}) + } + + // allow the transaction to be committed, continuing the packet lifecycle + }() + + err = callbackExecutor(cachedCtx) + if err == nil { + writeFn() + } + + return err +} + +// OnChanOpenInit defers to the underlying application +func (im IBCMiddleware) OnChanOpenInit( + ctx sdk.Context, + channelOrdering channeltypes.Order, + connectionHops []string, + portID, + channelID string, + channelCap *capabilitytypes.Capability, + counterparty channeltypes.Counterparty, + version string, +) (string, error) { + return im.app.OnChanOpenInit(ctx, channelOrdering, connectionHops, portID, channelID, channelCap, counterparty, version) +} + +// OnChanOpenTry defers to the underlying application +func (im IBCMiddleware) OnChanOpenTry( + ctx sdk.Context, + channelOrdering channeltypes.Order, + connectionHops []string, portID, + channelID string, + channelCap *capabilitytypes.Capability, + counterparty channeltypes.Counterparty, + counterpartyVersion string, +) (string, error) { + return im.app.OnChanOpenTry(ctx, channelOrdering, connectionHops, portID, channelID, channelCap, counterparty, counterpartyVersion) +} + +// OnChanOpenAck defers to the underlying application +func (im IBCMiddleware) OnChanOpenAck( + ctx sdk.Context, + portID, + channelID, + counterpartyChannelID, + counterpartyVersion string, +) error { + return im.app.OnChanOpenAck(ctx, portID, channelID, counterpartyChannelID, counterpartyVersion) +} + +// OnChanOpenConfirm defers to the underlying application +func (im IBCMiddleware) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string) error { + return im.app.OnChanOpenConfirm(ctx, portID, channelID) +} + +// OnChanCloseInit defers to the underlying application +func (im IBCMiddleware) OnChanCloseInit(ctx sdk.Context, portID, channelID string) error { + return im.app.OnChanCloseInit(ctx, portID, channelID) +} + +// OnChanCloseConfirm defers to the underlying application +func (im IBCMiddleware) OnChanCloseConfirm(ctx sdk.Context, portID, channelID string) error { + return im.app.OnChanCloseConfirm(ctx, portID, channelID) +} + +// GetAppVersion implements the ICS4Wrapper interface. Callbacks has no version, +// so the call is deferred to the underlying application. +func (im IBCMiddleware) GetAppVersion(ctx sdk.Context, portID, channelID string) (string, bool) { + return im.ics4Wrapper.GetAppVersion(ctx, portID, channelID) +} + +// UnmarshalPacketData defers to the underlying app to unmarshal the packet data. +// This function implements the optional PacketDataUnmarshaler interface. +func (im IBCMiddleware) UnmarshalPacketData(bz []byte) (interface{}, error) { + return im.app.UnmarshalPacketData(bz) +} diff --git a/modules/apps/callbacks/ibc_middleware_test.go b/modules/apps/callbacks/ibc_middleware_test.go new file mode 100644 index 00000000000..af30b87aac3 --- /dev/null +++ b/modules/apps/callbacks/ibc_middleware_test.go @@ -0,0 +1,965 @@ +package ibccallbacks_test + +import ( + "fmt" + + errorsmod "cosmossdk.io/errors" + + sdk "github.com/cosmos/cosmos-sdk/types" + + icacontrollertypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/controller/types" + ibccallbacks "github.com/cosmos/ibc-go/v7/modules/apps/callbacks" + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" + channelkeeper "github.com/cosmos/ibc-go/v7/modules/core/04-channel/keeper" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" + ibcerrors "github.com/cosmos/ibc-go/v7/modules/core/errors" + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" + ibctesting "github.com/cosmos/ibc-go/v7/testing" + ibcmock "github.com/cosmos/ibc-go/v7/testing/mock" +) + +func (s *CallbacksTestSuite) TestNewIBCMiddleware() { + testCases := []struct { + name string + instantiateFn func() + expError error + }{ + { + "success", + func() { + _ = ibccallbacks.NewIBCMiddleware(ibcmock.IBCModule{}, channelkeeper.Keeper{}, ibcmock.ContractKeeper{}, maxCallbackGas) + }, + nil, + }, + { + "panics with nil underlying app", + func() { + _ = ibccallbacks.NewIBCMiddleware(nil, channelkeeper.Keeper{}, ibcmock.ContractKeeper{}, maxCallbackGas) + }, + fmt.Errorf("underlying application does not implement %T", (*types.CallbacksCompatibleModule)(nil)), + }, + { + "panics with nil contract keeper", + func() { + _ = ibccallbacks.NewIBCMiddleware(ibcmock.IBCModule{}, channelkeeper.Keeper{}, nil, maxCallbackGas) + }, + fmt.Errorf("contract keeper cannot be nil"), + }, + { + "panics with nil ics4Wrapper", + func() { + _ = ibccallbacks.NewIBCMiddleware(ibcmock.IBCModule{}, nil, ibcmock.ContractKeeper{}, maxCallbackGas) + }, + fmt.Errorf("ICS4Wrapper cannot be nil"), + }, + { + "panics with zero maxCallbackGas", + func() { + _ = ibccallbacks.NewIBCMiddleware(ibcmock.IBCModule{}, channelkeeper.Keeper{}, ibcmock.ContractKeeper{}, uint64(0)) + }, + fmt.Errorf("maxCallbackGas cannot be zero"), + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + expPass := tc.expError == nil + if expPass { + s.Require().NotPanics(tc.instantiateFn, "unexpected panic: NewIBCMiddleware") + } else { + s.Require().PanicsWithError(tc.expError.Error(), tc.instantiateFn, "expected panic with error: ", tc.expError.Error()) + } + }) + } +} + +func (s *CallbacksTestSuite) TestWithICS4Wrapper() { + s.setupChains() + + cbsMiddleware := ibccallbacks.IBCMiddleware{} + s.Require().Nil(cbsMiddleware.GetICS4Wrapper()) + + cbsMiddleware.WithICS4Wrapper(s.chainA.App.GetIBCKeeper().ChannelKeeper) + ics4Wrapper := cbsMiddleware.GetICS4Wrapper() + + s.Require().IsType(channelkeeper.Keeper{}, ics4Wrapper) +} + +func (s *CallbacksTestSuite) TestSendPacket() { + var packetData transfertypes.FungibleTokenPacketData + + testCases := []struct { + name string + malleate func() + callbackType types.CallbackType + expPanic bool + expValue interface{} + }{ + { + "success", + func() {}, + types.CallbackTypeSendPacket, + false, + nil, + }, + { + "success: no-op on callback data is not valid", + func() { + //nolint:goconst + packetData.Memo = `{"src_callback": {"address": ""}}` + }, + "none", // improperly formatted callback data should result in no callback execution + false, + nil, + }, + { + "failure: ics4Wrapper SendPacket call fails", + func() { + s.path.EndpointA.ChannelID = "invalid-channel" + }, + "none", // ics4wrapper failure should result in no callback execution + false, + channeltypes.ErrChannelNotFound, + }, + { + "failure: callback execution fails, sender is not callback address", + func() { + packetData.Sender = ibcmock.MockCallbackUnauthorizedAddress + }, + types.CallbackTypeSendPacket, + false, + ibcmock.MockApplicationCallbackError, // execution failure on SendPacket should prevent packet sends + }, + { + "failure: callback execution reach out of gas, but sufficient gas provided by relayer", + func() { + packetData.Memo = fmt.Sprintf(`{"src_callback": {"address":"%s", "gas_limit":"400000"}}`, callbackAddr) + }, + types.CallbackTypeSendPacket, + true, + sdk.ErrorOutOfGas{Descriptor: fmt.Sprintf("mock %s callback panic", types.CallbackTypeSendPacket)}, + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + s.SetupTransferTest() + + // callbacks module is routed as top level middleware + transferStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(transfertypes.ModuleName) + s.Require().True(ok) + + packetData = transfertypes.NewFungibleTokenPacketData( + ibctesting.TestCoin.GetDenom(), ibctesting.TestCoin.Amount.String(), callbackAddr, + ibctesting.TestAccAddress, fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + ) + + chanCap := s.path.EndpointA.Chain.GetChannelCapability(s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID) + + tc.malleate() + + var ( + seq uint64 + err error + ) + sendPacket := func() { + seq, err = transferStack.(porttypes.Middleware).SendPacket(s.chainA.GetContext(), chanCap, s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID, s.chainB.GetTimeoutHeight(), 0, packetData.GetBytes()) + } + + expPass := tc.expValue == nil + switch { + case expPass: + sendPacket() + s.Require().Nil(err) + s.Require().Equal(uint64(1), seq) + case tc.expPanic: + s.Require().PanicsWithValue(tc.expValue, sendPacket) + default: + sendPacket() + s.Require().ErrorIs(tc.expValue.(error), err) + s.Require().Equal(uint64(0), seq) + } + + s.AssertHasExecutedExpectedCallback(tc.callbackType, expPass) + }) + } +} + +func (s *CallbacksTestSuite) TestOnAcknowledgementPacket() { + type expResult uint8 + const ( + noExecution expResult = iota + callbackFailed + callbackSuccess + ) + + var ( + packetData transfertypes.FungibleTokenPacketData + packet channeltypes.Packet + ack []byte + ctx sdk.Context + ) + + panicError := fmt.Errorf("panic error") + + testCases := []struct { + name string + malleate func() + expResult expResult + expError error + }{ + { + "success", + func() {}, + callbackSuccess, + nil, + }, + { + "failure: underlying app OnAcknolwedgePacket fails", + func() { + ack = []byte("invalid ack") + }, + noExecution, + ibcerrors.ErrUnknownRequest, + }, + { + "success: no-op on callback data is not valid", + func() { + //nolint:goconst + packetData.Memo = `{"src_callback": {"address": ""}}` + packet.Data = packetData.GetBytes() + }, + noExecution, + nil, + }, + { + "failure: callback execution reach out of gas, but sufficient gas provided by relayer", + func() { + packetData.Memo = fmt.Sprintf(`{"src_callback": {"address":"%s", "gas_limit":"400000"}}`, callbackAddr) + packet.Data = packetData.GetBytes() + }, + callbackFailed, + nil, + }, + { + "failure: callback execution panics on insufficient gas provided by relayer", + func() { + ctx = ctx.WithGasMeter(sdk.NewGasMeter(300_000)) + }, + callbackFailed, + panicError, + }, + { + "failure: callback execution fails, unauthorized address", + func() { + packetData.Sender = ibcmock.MockCallbackUnauthorizedAddress + packet.Data = packetData.GetBytes() + }, + callbackFailed, + nil, // execution failure in OnAcknowledgement should not block acknowledgement processing + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + s.SetupTransferTest() + + // set user gas limit above panic level in mock contract keeper + userGasLimit := 600000 + packetData = transfertypes.NewFungibleTokenPacketData( + ibctesting.TestCoin.GetDenom(), ibctesting.TestCoin.Amount.String(), callbackAddr, ibctesting.TestAccAddress, + fmt.Sprintf(`{"src_callback": {"address":"%s", "gas_limit":"%d"}}`, callbackAddr, userGasLimit), + ) + + packet = channeltypes.Packet{ + Sequence: 1, + SourcePort: s.path.EndpointA.ChannelConfig.PortID, + SourceChannel: s.path.EndpointA.ChannelID, + DestinationPort: s.path.EndpointB.ChannelConfig.PortID, + DestinationChannel: s.path.EndpointB.ChannelID, + Data: packetData.GetBytes(), + TimeoutHeight: s.chainB.GetTimeoutHeight(), + TimeoutTimestamp: 0, + } + + ack = channeltypes.NewResultAcknowledgement([]byte{1}).Acknowledgement() + ctx = s.chainA.GetContext() + + tc.malleate() + + // callbacks module is routed as top level middleware + transferStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(transfertypes.ModuleName) + s.Require().True(ok) + + onAcknowledgementPacket := func() error { + return transferStack.OnAcknowledgementPacket(ctx, packet, ack, s.chainA.SenderAccount.GetAddress()) + } + + switch tc.expError { + case nil: + err := onAcknowledgementPacket() + s.Require().Nil(err) + + case panicError: + s.Require().PanicsWithValue(sdk.ErrorOutOfGas{ + Descriptor: fmt.Sprintf("ibc %s callback out of gas; commitGasLimit: %d", types.CallbackTypeAcknowledgementPacket, userGasLimit), + }, func() { + _ = onAcknowledgementPacket() + }) + + default: + err := onAcknowledgementPacket() + s.Require().ErrorIs(tc.expError, err) + } + + sourceStatefulCounter := s.chainA.GetSimApp().MockContractKeeper.GetStateEntryCounter(s.chainA.GetContext()) + sourceCounters := s.chainA.GetSimApp().MockContractKeeper.Counters + + switch tc.expResult { + case noExecution: + s.Require().Len(sourceCounters, 0) + s.Require().Equal(uint8(0), sourceStatefulCounter) + + case callbackFailed: + s.Require().Len(sourceCounters, 1) + s.Require().Equal(1, sourceCounters[types.CallbackTypeAcknowledgementPacket]) + s.Require().Equal(uint8(0), sourceStatefulCounter) + + case callbackSuccess: + s.Require().Len(sourceCounters, 1) + s.Require().Equal(1, sourceCounters[types.CallbackTypeAcknowledgementPacket]) + s.Require().Equal(uint8(1), sourceStatefulCounter) + + } + }) + } +} + +func (s *CallbacksTestSuite) TestOnTimeoutPacket() { + type expResult uint8 + const ( + noExecution expResult = iota + callbackFailed + callbackSuccess + ) + + var ( + packetData transfertypes.FungibleTokenPacketData + packet channeltypes.Packet + ctx sdk.Context + ) + + panicError := fmt.Errorf("panic error") + + testCases := []struct { + name string + malleate func() + expResult expResult + expError error + }{ + { + "success", + func() {}, + callbackSuccess, + nil, + }, + { + "failure: underlying app OnTimeoutPacket fails", + func() { + packet.Data = []byte("invalid packet data") + }, + noExecution, + ibcerrors.ErrUnknownRequest, + }, + { + "success: no-op on callback data is not valid", + func() { + //nolint:goconst + packetData.Memo = `{"src_callback": {"address": ""}}` + packet.Data = packetData.GetBytes() + }, + noExecution, + nil, + }, + { + "failure: callback execution reach out of gas, but sufficient gas provided by relayer", + func() { + packetData.Memo = fmt.Sprintf(`{"src_callback": {"address":"%s", "gas_limit":"400000"}}`, callbackAddr) + packet.Data = packetData.GetBytes() + }, + callbackFailed, + nil, + }, + { + "failure: callback execution panics on insufficient gas provided by relayer", + func() { + ctx = ctx.WithGasMeter(sdk.NewGasMeter(300_000)) + }, + callbackFailed, + panicError, + }, + { + "failure: callback execution fails, unauthorized address", + func() { + packetData.Sender = ibcmock.MockCallbackUnauthorizedAddress + packet.Data = packetData.GetBytes() + }, + callbackFailed, + nil, // execution failure in OnTimeout should not block timeout processing + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + s.SetupTransferTest() + + // NOTE: we call send packet so transfer is setup with the correct logic to + // succeed on timeout + userGasLimit := 600_000 + timeoutTimestamp := uint64(s.chainB.GetContext().BlockTime().UnixNano()) + msg := transfertypes.NewMsgTransfer( + s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID, + ibctesting.TestCoin, s.chainA.SenderAccount.GetAddress().String(), + s.chainB.SenderAccount.GetAddress().String(), clienttypes.ZeroHeight(), timeoutTimestamp, + fmt.Sprintf(`{"src_callback": {"address":"%s", "gas_limit":"%d"}}`, ibctesting.TestAccAddress, userGasLimit), // set user gas limit above panic level in mock contract keeper + ) + + res, err := s.chainA.SendMsgs(msg) + s.Require().NoError(err) + s.Require().NotNil(res) + + packet, err = ibctesting.ParsePacketFromEvents(res.GetEvents().ToABCIEvents()) + s.Require().NoError(err) + s.Require().NotNil(packet) + + err = transfertypes.ModuleCdc.UnmarshalJSON(packet.Data, &packetData) + s.Require().NoError(err) + + ctx = s.chainA.GetContext() + + tc.malleate() + + // callbacks module is routed as top level middleware + transferStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(transfertypes.ModuleName) + s.Require().True(ok) + + onTimeoutPacket := func() error { + return transferStack.OnTimeoutPacket(ctx, packet, s.chainA.SenderAccount.GetAddress()) + } + + switch tc.expError { + case nil: + err := onTimeoutPacket() + s.Require().Nil(err) + + case panicError: + s.Require().PanicsWithValue(sdk.ErrorOutOfGas{ + Descriptor: fmt.Sprintf("ibc %s callback out of gas; commitGasLimit: %d", types.CallbackTypeTimeoutPacket, userGasLimit), + }, func() { + _ = onTimeoutPacket() + }) + + default: + err := onTimeoutPacket() + s.Require().ErrorIs(tc.expError, err) + } + + sourceStatefulCounter := s.chainA.GetSimApp().MockContractKeeper.GetStateEntryCounter(s.chainA.GetContext()) + sourceCounters := s.chainA.GetSimApp().MockContractKeeper.Counters + + // account for SendPacket succeeding + switch tc.expResult { + case noExecution: + s.Require().Len(sourceCounters, 1) + s.Require().Equal(uint8(1), sourceStatefulCounter) + + case callbackFailed: + s.Require().Len(sourceCounters, 2) + s.Require().Equal(1, sourceCounters[types.CallbackTypeTimeoutPacket]) + s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket]) + s.Require().Equal(uint8(1), sourceStatefulCounter) + + case callbackSuccess: + s.Require().Len(sourceCounters, 2) + s.Require().Equal(1, sourceCounters[types.CallbackTypeTimeoutPacket]) + s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket]) + s.Require().Equal(uint8(2), sourceStatefulCounter) + } + }) + } +} + +func (s *CallbacksTestSuite) TestOnRecvPacket() { + type expResult uint8 + const ( + noExecution expResult = iota + callbackFailed + callbackSuccess + ) + + var ( + packetData transfertypes.FungibleTokenPacketData + packet channeltypes.Packet + ctx sdk.Context + ) + + successAck := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) + panicAck := channeltypes.NewErrorAcknowledgement(fmt.Errorf("panic")) + + testCases := []struct { + name string + malleate func() + expResult expResult + expAck ibcexported.Acknowledgement + }{ + { + "success", + func() {}, + callbackSuccess, + successAck, + }, + { + "failure: underlying app OnRecvPacket fails", + func() { + packet.Data = []byte("invalid packet data") + }, + noExecution, + channeltypes.NewErrorAcknowledgement(ibcerrors.ErrInvalidType), + }, + { + "success: no-op on callback data is not valid", + func() { + //nolint:goconst + packetData.Memo = `{"dest_callback": {"address": ""}}` + packet.Data = packetData.GetBytes() + }, + noExecution, + successAck, + }, + { + "failure: callback execution reach out of gas, but sufficient gas provided by relayer", + func() { + packetData.Memo = fmt.Sprintf(`{"dest_callback": {"address":"%s", "gas_limit":"400000"}}`, callbackAddr) + packet.Data = packetData.GetBytes() + }, + callbackFailed, + successAck, + }, + { + "failure: callback execution panics on insufficient gas provided by relayer", + func() { + ctx = ctx.WithGasMeter(sdk.NewGasMeter(300_000)) + }, + callbackFailed, + panicAck, + }, + /* + TODO: https://github.com/cosmos/ibc-go/issues/4309 + { + "failure: callback execution fails", + func() {}, + callbackFailed, + successAck, + }, + */ + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + s.SetupTransferTest() + + // set user gas limit above panic level in mock contract keeper + userGasLimit := 600_000 + packetData = transfertypes.NewFungibleTokenPacketData( + ibctesting.TestCoin.GetDenom(), ibctesting.TestCoin.Amount.String(), ibctesting.TestAccAddress, s.chainB.SenderAccount.GetAddress().String(), + fmt.Sprintf(`{"dest_callback": {"address":"%s", "gas_limit":"%d"}}`, ibctesting.TestAccAddress, userGasLimit), + ) + + packet = channeltypes.Packet{ + Sequence: 1, + SourcePort: s.path.EndpointA.ChannelConfig.PortID, + SourceChannel: s.path.EndpointA.ChannelID, + DestinationPort: s.path.EndpointB.ChannelConfig.PortID, + DestinationChannel: s.path.EndpointB.ChannelID, + Data: packetData.GetBytes(), + TimeoutHeight: s.chainB.GetTimeoutHeight(), + TimeoutTimestamp: 0, + } + + ctx = s.chainB.GetContext() + + tc.malleate() + + // callbacks module is routed as top level middleware + transferStack, ok := s.chainB.App.GetIBCKeeper().Router.GetRoute(transfertypes.ModuleName) + s.Require().True(ok) + + onRecvPacket := func() ibcexported.Acknowledgement { + return transferStack.OnRecvPacket(ctx, packet, s.chainB.SenderAccount.GetAddress()) + } + + switch tc.expAck { + case successAck: + ack := onRecvPacket() + s.Require().NotNil(ack) + + case panicAck: + s.Require().PanicsWithValue(sdk.ErrorOutOfGas{ + Descriptor: fmt.Sprintf("ibc %s callback out of gas; commitGasLimit: %d", types.CallbackTypeReceivePacket, userGasLimit), + }, func() { + _ = onRecvPacket() + }) + + default: + ack := onRecvPacket() + s.Require().Equal(tc.expAck, ack) + } + + destStatefulCounter := s.chainB.GetSimApp().MockContractKeeper.GetStateEntryCounter(s.chainB.GetContext()) + destCounters := s.chainB.GetSimApp().MockContractKeeper.Counters + + switch tc.expResult { + case noExecution: + s.Require().Len(destCounters, 0) + s.Require().Equal(uint8(0), destStatefulCounter) + + case callbackFailed: + s.Require().Len(destCounters, 1) + s.Require().Equal(1, destCounters[types.CallbackTypeReceivePacket]) + s.Require().Equal(uint8(0), destStatefulCounter) + + case callbackSuccess: + s.Require().Len(destCounters, 1) + s.Require().Equal(1, destCounters[types.CallbackTypeReceivePacket]) + s.Require().Equal(uint8(1), destStatefulCounter) + } + }) + } +} + +func (s *CallbacksTestSuite) TestWriteAcknowledgement() { + var ( + packetData transfertypes.FungibleTokenPacketData + packet channeltypes.Packet + ctx sdk.Context + ack ibcexported.Acknowledgement + ) + + successAck := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) + + testCases := []struct { + name string + malleate func() + callbackType types.CallbackType + expError error + }{ + { + "success", + func() { + ack = successAck + }, + types.CallbackTypeReceivePacket, + nil, + }, + { + "success: no-op on callback data is not valid", + func() { + packetData.Memo = `{"dest_callback": {"address": ""}}` + packet.Data = packetData.GetBytes() + }, + "none", // improperly formatted callback data should result in no callback execution + nil, + }, + { + "failure: ics4Wrapper WriteAcknowledgement call fails", + func() { + packet.DestinationChannel = "invalid-channel" + }, + "none", + channeltypes.ErrChannelNotFound, + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + s.SetupTransferTest() + + // set user gas limit above panic level in mock contract keeper + packetData = transfertypes.NewFungibleTokenPacketData( + ibctesting.TestCoin.GetDenom(), ibctesting.TestCoin.Amount.String(), ibctesting.TestAccAddress, s.chainB.SenderAccount.GetAddress().String(), + fmt.Sprintf(`{"dest_callback": {"address":"%s", "gas_limit":"600000"}}`, ibctesting.TestAccAddress), + ) + + packet = channeltypes.Packet{ + Sequence: 1, + SourcePort: s.path.EndpointA.ChannelConfig.PortID, + SourceChannel: s.path.EndpointA.ChannelID, + DestinationPort: s.path.EndpointB.ChannelConfig.PortID, + DestinationChannel: s.path.EndpointB.ChannelID, + Data: packetData.GetBytes(), + TimeoutHeight: s.chainB.GetTimeoutHeight(), + TimeoutTimestamp: 0, + } + + ctx = s.chainB.GetContext() + + chanCap := s.chainB.GetChannelCapability(s.path.EndpointB.ChannelConfig.PortID, s.path.EndpointB.ChannelID) + + tc.malleate() + + // callbacks module is routed as top level middleware + transferStack, ok := s.chainB.App.GetIBCKeeper().Router.GetRoute(transfertypes.ModuleName) + s.Require().True(ok) + + err := transferStack.(porttypes.Middleware).WriteAcknowledgement(ctx, chanCap, packet, ack) + + expPass := tc.expError == nil + s.AssertHasExecutedExpectedCallback(tc.callbackType, expPass) + + if expPass { + s.Require().NoError(err) + } else { + s.Require().ErrorIs(tc.expError, err) + } + }) + } +} + +func (s *CallbacksTestSuite) TestProcessCallback() { + var ( + callbackType types.CallbackType + callbackData types.CallbackData + ctx sdk.Context + callbackExecutor func(sdk.Context) error + ) + + callbackError := fmt.Errorf("callbackExecutor error") + + testCases := []struct { + name string + malleate func() + expPanic bool + expValue interface{} + }{ + { + "success", + func() {}, + false, + nil, + }, + { + "success: callbackExecutor panic, but not out of gas", + func() { + callbackExecutor = func(cachedCtx sdk.Context) error { + panic("callbackExecutor panic") + } + }, + false, + nil, + }, + { + "success: callbackExecutor oog panic, but retry is not allowed", + func() { + executionGas := callbackData.ExecutionGasLimit + callbackExecutor = func(cachedCtx sdk.Context) error { + cachedCtx.GasMeter().ConsumeGas(executionGas+1, "callbackExecutor oog panic") + return nil + } + }, + false, + nil, + }, + { + "failure: callbackExecutor error", + func() { + callbackExecutor = func(cachedCtx sdk.Context) error { + return callbackError + } + }, + false, + callbackError, + }, + { + "failure: callbackExecutor panic, not out of gas, and SendPacket", + func() { + callbackType = types.CallbackTypeSendPacket + callbackExecutor = func(cachedCtx sdk.Context) error { + panic("callbackExecutor panic") + } + }, + true, + "callbackExecutor panic", + }, + { + "failure: callbackExecutor oog panic, but retry is allowed", + func() { + executionGas := callbackData.ExecutionGasLimit + callbackData.CommitGasLimit = executionGas + 1 + callbackExecutor = func(cachedCtx sdk.Context) error { + cachedCtx.GasMeter().ConsumeGas(executionGas+1, "callbackExecutor oog panic") + return nil + } + }, + true, + sdk.ErrorOutOfGas{Descriptor: fmt.Sprintf("ibc %s callback out of gas; commitGasLimit: %d", types.CallbackTypeReceivePacket, 1000000+1)}, + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + s.SetupMockFeeTest() + + // set mock packet, it is only used in logs and not in callback execution + mockPacket := channeltypes.NewPacket( + ibcmock.MockPacketData, 1, s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID, + s.path.EndpointB.ChannelConfig.PortID, s.path.EndpointB.ChannelID, clienttypes.NewHeight(0, 100), 0) + + // set a callback data that does not allow retry + callbackData = types.CallbackData{ + CallbackAddress: s.chainB.SenderAccount.GetAddress().String(), + ExecutionGasLimit: 1000000, + SenderAddress: s.chainB.SenderAccount.GetAddress().String(), + CommitGasLimit: 600000, + } + + // this only makes a difference if it is SendPacket + callbackType = types.CallbackTypeReceivePacket + + ctx = s.chainB.GetContext() + + // set a callback executor that will always succeed + callbackExecutor = func(cachedCtx sdk.Context) error { + return nil + } + + tc.malleate() + + module, _, err := s.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(s.chainA.GetContext(), ibctesting.MockFeePort) + s.Require().NoError(err) + cbs, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(module) + s.Require().True(ok) + mockCallbackStack, ok := cbs.(ibccallbacks.IBCMiddleware) + s.Require().True(ok) + + processCallback := func() { + err = mockCallbackStack.ProcessCallback(ctx, mockPacket, callbackType, callbackData, callbackExecutor) + } + + expPass := tc.expValue == nil + switch { + case expPass: + processCallback() + s.Require().NoError(err) + case tc.expPanic: + s.Require().PanicsWithValue(tc.expValue, processCallback) + default: + processCallback() + s.Require().ErrorIs(tc.expValue.(error), err) + } + }) + } +} + +func (s *CallbacksTestSuite) TestUnmarshalPacketData() { + s.setupChains() + + // We will pass the function call down the transfer stack to the transfer module + // transfer stack UnmarshalPacketData call order: callbacks -> fee -> transfer + transferStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(transfertypes.ModuleName) + s.Require().True(ok) + + unmarshalerStack, ok := transferStack.(types.CallbacksCompatibleModule) + s.Require().True(ok) + + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: ibctesting.TestAccAddress, + Receiver: ibctesting.TestAccAddress, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}, "dest_callback": {"address":"%s"}}`, ibctesting.TestAccAddress, ibctesting.TestAccAddress), + } + data := expPacketData.GetBytes() + + packetData, err := unmarshalerStack.UnmarshalPacketData(data) + s.Require().NoError(err) + s.Require().Equal(expPacketData, packetData) +} + +func (s *CallbacksTestSuite) TestGetAppVersion() { + s.SetupICATest() + + // Obtain an IBC stack for testing. The function call will use the top of the stack which calls + // directly to the channel keeper. Calling from a further down module in the stack is not necessary + // for this test. + icaControllerStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(icacontrollertypes.SubModuleName) + s.Require().True(ok) + + controllerStack := icaControllerStack.(porttypes.Middleware) + appVersion, found := controllerStack.GetAppVersion(s.chainA.GetContext(), s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID) + s.Require().True(found) + s.Require().Equal(s.path.EndpointA.ChannelConfig.Version, appVersion) +} + +func (s *CallbacksTestSuite) TestOnChanCloseInit() { + s.SetupICATest() + + // We will pass the function call down the icacontroller stack to the icacontroller module + // icacontroller stack OnChanCloseInit call order: callbacks -> fee -> icacontroller + icaControllerStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(icacontrollertypes.SubModuleName) + s.Require().True(ok) + + controllerStack := icaControllerStack.(porttypes.Middleware) + err := controllerStack.OnChanCloseInit(s.chainA.GetContext(), s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID) + // we just check that this call is passed down to the icacontroller to return an error + s.Require().ErrorIs(errorsmod.Wrap(ibcerrors.ErrInvalidRequest, "user cannot close channel"), err) +} + +func (s *CallbacksTestSuite) TestOnChanCloseConfirm() { + s.SetupICATest() + + // We will pass the function call down the icacontroller stack to the icacontroller module + // icacontroller stack OnChanCloseConfirm call order: callbacks -> fee -> icacontroller + icaControllerStack, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(icacontrollertypes.SubModuleName) + s.Require().True(ok) + + controllerStack := icaControllerStack.(porttypes.Middleware) + err := controllerStack.OnChanCloseConfirm(s.chainA.GetContext(), s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID) + // we just check that this call is passed down to the icacontroller + s.Require().NoError(err) +} + +func (s *CallbacksTestSuite) TestOnRecvPacketAsyncAck() { + s.SetupMockFeeTest() + + module, _, err := s.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(s.chainA.GetContext(), ibctesting.MockFeePort) + s.Require().NoError(err) + cbs, ok := s.chainA.App.GetIBCKeeper().Router.GetRoute(module) + s.Require().True(ok) + mockFeeCallbackStack, ok := cbs.(porttypes.Middleware) + s.Require().True(ok) + + packet := channeltypes.NewPacket( + ibcmock.MockAsyncPacketData, + s.chainA.SenderAccount.GetSequence(), + s.path.EndpointA.ChannelConfig.PortID, + s.path.EndpointA.ChannelID, + s.path.EndpointB.ChannelConfig.PortID, + s.path.EndpointB.ChannelID, + clienttypes.NewHeight(0, 100), + 0, + ) + + ack := mockFeeCallbackStack.OnRecvPacket(s.chainA.GetContext(), packet, s.chainA.SenderAccount.GetAddress()) + s.Require().Nil(ack) + s.AssertHasExecutedExpectedCallback("none", true) +} diff --git a/modules/apps/callbacks/ica_test.go b/modules/apps/callbacks/ica_test.go new file mode 100644 index 00000000000..758336f0dd2 --- /dev/null +++ b/modules/apps/callbacks/ica_test.go @@ -0,0 +1,197 @@ +package ibccallbacks_test + +import ( + "fmt" + "time" + + "github.com/cosmos/gogoproto/proto" + + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + + icacontrollertypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/controller/types" + icahosttypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/host/types" + icatypes "github.com/cosmos/ibc-go/v7/modules/apps/27-interchain-accounts/types" + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + ibctesting "github.com/cosmos/ibc-go/v7/testing" +) + +func (s *CallbacksTestSuite) TestICACallbacks() { + // Destination callbacks are not supported for ICA packets + testCases := []struct { + name string + icaMemo string + expCallback types.CallbackType + expSuccess bool + }{ + { + "success: send ica tx with no memo", + "", + "none", + true, + }, + { + "success: dest callback", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, callbackAddr), + "none", + true, + }, + { + "success: source callback", + fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeAcknowledgementPacket, + true, + }, + { + "success: source callback with other json fields", + fmt.Sprintf(`{"src_callback": {"address": "%s"}, "something_else": {}}`, callbackAddr), + types.CallbackTypeAcknowledgementPacket, + true, + }, + { + "success: source callback with malformed json", + fmt.Sprintf(`{"src_callback": {"address": "%s"}, malformed}`, callbackAddr), + "none", + true, + }, + { + "success: source callback with missing address", + `{"src_callback": {"address": ""}}`, + "none", + true, + }, + { + "failure: source callback with low gas (panic)", + fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "350000"}}`, callbackAddr), + types.CallbackTypeSendPacket, + false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + icaAddr := s.SetupICATest() + + s.ExecuteICATx(icaAddr, tc.icaMemo, 1) + s.AssertHasExecutedExpectedCallback(tc.expCallback, tc.expSuccess) + }) + } +} + +func (s *CallbacksTestSuite) TestICATimeoutCallbacks() { + // ICA channels are closed after a timeout packet is executed + testCases := []struct { + name string + icaMemo string + expCallback types.CallbackType + expSuccess bool + }{ + { + "success: send ica tx timeout with no memo", + "", + "none", + true, + }, + { + "success: dest callback", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, callbackAddr), + "none", + true, + }, + { + "success: source callback", + fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeTimeoutPacket, + true, + }, + { + "failure: source callback with low gas (panic)", + fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "350000"}}`, callbackAddr), + types.CallbackTypeSendPacket, + false, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + icaAddr := s.SetupICATest() + + s.ExecuteICATimeout(icaAddr, tc.icaMemo, 1) + s.AssertHasExecutedExpectedCallback(tc.expCallback, tc.expSuccess) + }) + } +} + +// ExecuteICATx executes a stakingtypes.MsgDelegate on chainB by sending a packet containing the msg to chainB +func (s *CallbacksTestSuite) ExecuteICATx(icaAddress, memo string, seq uint64) { + timeoutTimestamp := uint64(s.chainA.GetContext().BlockTime().Add(time.Minute).UnixNano()) + icaOwner := s.chainA.SenderAccount.GetAddress().String() + connectionID := s.path.EndpointA.ConnectionID + // build the interchain accounts packet data + packetData := s.buildICAMsgDelegatePacketData(icaAddress, memo) + msg := icacontrollertypes.NewMsgSendTx(icaOwner, connectionID, timeoutTimestamp, packetData) + + res, err := s.chainA.SendMsgs(msg) + if err != nil { + return // we return if send packet is rejected + } + + packet, err := ibctesting.ParsePacketFromEvents(res.GetEvents().ToABCIEvents()) + s.Require().NoError(err) + + err = s.path.RelayPacket(packet) + s.Require().NoError(err) +} + +// ExecuteICATx sends and times out an ICA tx +func (s *CallbacksTestSuite) ExecuteICATimeout(icaAddress, memo string, seq uint64) { + relativeTimeout := uint64(1) + icaOwner := s.chainA.SenderAccount.GetAddress().String() + connectionID := s.path.EndpointA.ConnectionID + // build the interchain accounts packet data + packetData := s.buildICAMsgDelegatePacketData(icaAddress, memo) + msg := icacontrollertypes.NewMsgSendTx(icaOwner, connectionID, relativeTimeout, packetData) + + res, err := s.chainA.SendMsgs(msg) + if err != nil { + return // we return if send packet is rejected + } + + packet, err := ibctesting.ParsePacketFromEvents(res.GetEvents().ToABCIEvents()) + s.Require().NoError(err) + + // proof query requires up to date client + err = s.path.EndpointA.UpdateClient() + s.Require().NoError(err) + + err = s.path.EndpointA.TimeoutPacket(packet) + s.Require().NoError(err) +} + +// buildICAMsgDelegatePacketData builds a packetData containing a stakingtypes.MsgDelegate to be executed on chainB +func (s *CallbacksTestSuite) buildICAMsgDelegatePacketData(icaAddress string, memo string) icatypes.InterchainAccountPacketData { + // prepare a simple stakingtypes.MsgDelegate to be used as the interchain account msg executed on chainB + validatorAddr := (sdk.ValAddress)(s.chainB.Vals.Validators[0].Address) + msgDelegate := &stakingtypes.MsgDelegate{ + DelegatorAddress: icaAddress, + ValidatorAddress: validatorAddr.String(), + Amount: sdk.NewCoin(sdk.DefaultBondDenom, sdkmath.NewInt(5000)), + } + + // ensure chainB is allowed to execute stakingtypes.MsgDelegate + params := icahosttypes.NewParams(true, []string{sdk.MsgTypeURL(msgDelegate)}) + s.chainB.GetSimApp().ICAHostKeeper.SetParams(s.chainB.GetContext(), params) + + data, err := icatypes.SerializeCosmosTx(s.chainA.GetSimApp().AppCodec(), []proto.Message{msgDelegate}, icatypes.EncodingProtobuf) + s.Require().NoError(err) + + icaPacketData := icatypes.InterchainAccountPacketData{ + Type: icatypes.EXECUTE_TX, + Data: data, + Memo: memo, + } + + return icaPacketData +} diff --git a/modules/apps/callbacks/transfer_test.go b/modules/apps/callbacks/transfer_test.go new file mode 100644 index 00000000000..bfe55691216 --- /dev/null +++ b/modules/apps/callbacks/transfer_test.go @@ -0,0 +1,215 @@ +package ibccallbacks_test + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" + ibctesting "github.com/cosmos/ibc-go/v7/testing" +) + +var callbackAddr = ibctesting.TestAccAddress + +func (s *CallbacksTestSuite) TestTransferCallbacks() { + testCases := []struct { + name string + transferMemo string + expCallback types.CallbackType + expSuccess bool + }{ + { + "success: transfer with no memo", + "", + "none", + true, + }, + { + "success: dest callback", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeReceivePacket, + true, + }, + { + "success: dest callback with other json fields", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}, "something_else": {}}`, callbackAddr), + types.CallbackTypeReceivePacket, + true, + }, + { + "success: dest callback with malformed json", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}, malformed}`, callbackAddr), + "none", + true, + }, + { + "success: dest callback with missing address", + `{"dest_callback": {"address": ""}}`, + "none", + true, + }, + { + "success: source callback", + fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeAcknowledgementPacket, + true, + }, + { + "success: source callback with other json fields", + fmt.Sprintf(`{"src_callback": {"address": "%s"}, "something_else": {}}`, callbackAddr), + types.CallbackTypeAcknowledgementPacket, + true, + }, + { + "success: source callback with malformed json", + fmt.Sprintf(`{"src_callback": {"address": "%s"}, malformed}`, callbackAddr), + "none", + true, + }, + { + "success: source callback with missing address", + `{"src_callback": {"address": ""}}`, + "none", + true, + }, + { + "failure: dest callback with low gas (panic)", + fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + types.CallbackTypeReceivePacket, + false, + }, + { + "failure: source callback with low gas (panic)", + fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + types.CallbackTypeSendPacket, + false, + }, + } + + for _, tc := range testCases { + s.SetupTransferTest() + + s.ExecuteTransfer(tc.transferMemo) + s.AssertHasExecutedExpectedCallback(tc.expCallback, tc.expSuccess) + } +} + +func (s *CallbacksTestSuite) TestTransferTimeoutCallbacks() { + testCases := []struct { + name string + transferMemo string + expCallback types.CallbackType + expSuccess bool + }{ + { + "success: transfer with no memo", + "", + "none", + true, + }, + { + "success: dest callback", + fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, callbackAddr), + "none", // timeouts don't reach destination chain execution + true, + }, + { + "success: source callback", + fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, callbackAddr), + types.CallbackTypeTimeoutPacket, + true, + }, + { + "success: dest callback with low gas (panic)", + fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + "none", // timeouts don't reach destination chain execution + true, + }, + { + "failure: source callback with low gas (panic)", + fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "450000"}}`, callbackAddr), + types.CallbackTypeSendPacket, + false, + }, + } + + for _, tc := range testCases { + s.SetupTransferTest() + + s.ExecuteTransferTimeout(tc.transferMemo, 1) + s.AssertHasExecutedExpectedCallback(tc.expCallback, tc.expSuccess) + } +} + +// ExecuteTransfer executes a transfer message on chainA for ibctesting.TestCoin (100 "stake"). +// It checks that the transfer is successful and that the packet is relayed to chainB. +func (s *CallbacksTestSuite) ExecuteTransfer(memo string) { + escrowAddress := transfertypes.GetEscrowAddress(s.path.EndpointA.ChannelConfig.PortID, s.path.EndpointA.ChannelID) + // record the balance of the escrow address before the transfer + escrowBalance := s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), escrowAddress, sdk.DefaultBondDenom) + // record the balance of the receiving address before the transfer + voucherDenomTrace := transfertypes.ParseDenomTrace(transfertypes.GetPrefixedDenom(s.path.EndpointB.ChannelConfig.PortID, s.path.EndpointB.ChannelID, sdk.DefaultBondDenom)) + receiverBalance := s.chainB.GetSimApp().BankKeeper.GetBalance(s.chainB.GetContext(), s.chainB.SenderAccount.GetAddress(), voucherDenomTrace.IBCDenom()) + + amount := ibctesting.TestCoin + msg := transfertypes.NewMsgTransfer( + s.path.EndpointA.ChannelConfig.PortID, + s.path.EndpointA.ChannelID, + amount, + s.chainA.SenderAccount.GetAddress().String(), + s.chainB.SenderAccount.GetAddress().String(), + clienttypes.NewHeight(1, 100), 0, memo, + ) + + res, err := s.chainA.SendMsgs(msg) + if err != nil { + return // we return if send packet is rejected + } + + packet, err := ibctesting.ParsePacketFromEvents(res.GetEvents().ToABCIEvents()) + s.Require().NoError(err) + + // relay send + err = s.path.RelayPacket(packet) + s.Require().NoError(err) // relay committed + + // check that the escrow address balance increased by 100 + s.Require().Equal(escrowBalance.Add(amount), s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), escrowAddress, sdk.DefaultBondDenom)) + // check that the receiving address balance increased by 100 + s.Require().Equal(receiverBalance.AddAmount(sdk.NewInt(100)), s.chainB.GetSimApp().BankKeeper.GetBalance(s.chainB.GetContext(), s.chainB.SenderAccount.GetAddress(), voucherDenomTrace.IBCDenom())) +} + +// ExecuteTransferTimeout executes a transfer message on chainA for 100 denom. +// This message is not relayed to chainB, and it times out on chainA. +func (s *CallbacksTestSuite) ExecuteTransferTimeout(memo string, nextSeqRecv uint64) { + timeoutHeight := clienttypes.GetSelfHeight(s.chainB.GetContext()) + timeoutTimestamp := uint64(s.chainB.GetContext().BlockTime().UnixNano()) + + amount := ibctesting.TestCoin + msg := transfertypes.NewMsgTransfer( + s.path.EndpointA.ChannelConfig.PortID, + s.path.EndpointA.ChannelID, + amount, + s.chainA.SenderAccount.GetAddress().String(), + s.chainB.SenderAccount.GetAddress().String(), + timeoutHeight, timeoutTimestamp, memo, + ) + + res, err := s.chainA.SendMsgs(msg) + if err != nil { + return // we return if send packet is rejected + } + + packet, err := ibctesting.ParsePacketFromEvents(res.GetEvents().ToABCIEvents()) + s.Require().NoError(err) // packet committed + s.Require().NotNil(packet) + + // need to update chainA's client representing chainB to prove missing ack + err = s.path.EndpointA.UpdateClient() + s.Require().NoError(err) + + err = s.path.EndpointA.TimeoutPacket(packet) + s.Require().NoError(err) // timeout committed +} diff --git a/modules/apps/callbacks/types/callbacks.go b/modules/apps/callbacks/types/callbacks.go new file mode 100644 index 00000000000..a868245422a --- /dev/null +++ b/modules/apps/callbacks/types/callbacks.go @@ -0,0 +1,198 @@ +package types + +import ( + "strconv" + "strings" + + errorsmod "cosmossdk.io/errors" + + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" +) + +/* + +ADR-8 implementation + +The Memo is used to ensure that the callback is desired by the user. This allows a user to send a packet to an ADR-8 enabled contract. + +The Memo format is defined like so: + +```json +{ + // ... other memo fields we don't care about + "src_callback": { + "address": {stringCallbackAddress}, + + // optional fields + "gas_limit": {stringForCallback} + }, + "dest_callback": { + "address": {stringCallbackAddress}, + + // optional fields + "gas_limit": {stringForCallback} + } +} +``` + +We will pass the packet sender info (if available) to the contract keeper for source callback executions. This will allow the contract +keeper to verify that the packet sender is the same as the callback address if desired. + +*/ + +// CallbacksCompatibleModule is an interface that combines the IBCModule and PacketDataUnmarshaler +// interfaces to assert that the underlying application supports both. +type CallbacksCompatibleModule interface { + porttypes.IBCModule + porttypes.PacketDataUnmarshaler +} + +// CallbackData is the callback data parsed from the packet. +type CallbackData struct { + // CallbackAddress is the address of the callback actor. + CallbackAddress string + // ExecutionGasLimit is the gas limit which will be used for the callback execution. + ExecutionGasLimit uint64 + // SenderAddress is the sender of the packet. This is passed to the contract keeper + // to verify that the packet sender is the same as the callback address if desired. + // This address is empty during destination callback execution. + // This address may be empty if the sender is unknown or undefined. + SenderAddress string + // CommitGasLimit is the gas needed to commit the callback even if the callback + // execution fails due to out of gas. + // This parameter is only used in event emissions, or logging. + CommitGasLimit uint64 +} + +// GetSourceCallbackData parses the packet data and returns the source callback data. +func GetSourceCallbackData( + packetDataUnmarshaler porttypes.PacketDataUnmarshaler, + packet ibcexported.PacketI, remainingGas uint64, maxGas uint64, +) (CallbackData, error) { + return getCallbackData(packetDataUnmarshaler, packet, remainingGas, maxGas, SourceCallbackKey) +} + +// GetDestCallbackData parses the packet data and returns the destination callback data. +func GetDestCallbackData( + packetDataUnmarshaler porttypes.PacketDataUnmarshaler, + packet ibcexported.PacketI, remainingGas uint64, maxGas uint64, +) (CallbackData, error) { + return getCallbackData(packetDataUnmarshaler, packet, remainingGas, maxGas, DestinationCallbackKey) +} + +// getCallbackData parses the packet data and returns the callback data. +// It also checks that the remaining gas is greater than the gas limit specified in the packet data. +// The addressGetter and gasLimitGetter functions are used to retrieve the callback +// address and gas limit from the callback data. +func getCallbackData( + packetDataUnmarshaler porttypes.PacketDataUnmarshaler, + packet ibcexported.PacketI, remainingGas, + maxGas uint64, callbackKey string, +) (CallbackData, error) { + // unmarshal packet data + unmarshaledData, err := packetDataUnmarshaler.UnmarshalPacketData(packet.GetData()) + if err != nil { + return CallbackData{}, errorsmod.Wrap(ErrCannotUnmarshalPacketData, err.Error()) + } + + packetDataProvider, ok := unmarshaledData.(ibcexported.PacketDataProvider) + if !ok { + return CallbackData{}, ErrNotPacketDataProvider + } + + callbackData, ok := packetDataProvider.GetCustomPacketData(callbackKey).(map[string]interface{}) + if callbackData == nil || !ok { + return CallbackData{}, ErrCallbackKeyNotFound + } + + // get the callback address from the callback data + callbackAddress := getCallbackAddress(callbackData) + if strings.TrimSpace(callbackAddress) == "" { + return CallbackData{}, ErrCallbackAddressNotFound + } + + // retrieve packet sender from packet data if possible and if needed + var packetSender string + if callbackKey == SourceCallbackKey { + packetData, ok := unmarshaledData.(ibcexported.PacketData) + if ok { + packetSender = packetData.GetPacketSender(packet.GetSourcePort()) + } + } + + // get the gas limit from the callback data + executionGasLimit, commitGasLimit := computeExecAndCommitGasLimit(callbackData, remainingGas, maxGas) + + return CallbackData{ + CallbackAddress: callbackAddress, + ExecutionGasLimit: executionGasLimit, + SenderAddress: packetSender, + CommitGasLimit: commitGasLimit, + }, nil +} + +func computeExecAndCommitGasLimit(callbackData map[string]interface{}, remainingGas, maxGas uint64) (uint64, uint64) { + // get the gas limit from the callback data + commitGasLimit := getUserDefinedGasLimit(callbackData) + + // ensure user defined gas limit does not exceed the max gas limit + if commitGasLimit == 0 || commitGasLimit > maxGas { + commitGasLimit = maxGas + } + + // account for the remaining gas in the context being less than the desired gas limit for the callback execution + // in this case, the callback execution may be retried upon failure + executionGasLimit := commitGasLimit + if remainingGas < executionGasLimit { + executionGasLimit = remainingGas + } + + return executionGasLimit, commitGasLimit +} + +// getUserDefinedGasLimit returns the custom gas limit provided for callbacks if it is +// in the callback data. It is assumed that callback data is not nil. +// If no gas limit is specified or the gas limit is improperly formatted, 0 is returned. +// +// The memo is expected to specify the user defined gas limit in the following format: +// { "{callbackKey}": { ... , "gas_limit": {stringForCallback} } +// +// Note: the user defined gas limit must be set as a string and not a json number. +func getUserDefinedGasLimit(callbackData map[string]interface{}) uint64 { + // the gas limit must be specified as a string and not a json number + gasLimit, ok := callbackData[UserDefinedGasLimitKey].(string) + if !ok { + return 0 + } + + userGas, err := strconv.ParseUint(gasLimit, 10, 64) + if err != nil { + return 0 + } + + return userGas +} + +// getCallbackAddress returns the callback address if it is specified in the callback data. +// It is assumed that callback data is not nil. +// If no callback address is specified or the memo is improperly formatted, an empty string is returned. +// +// The memo is expected to contain the callback address in the following format: +// { "{callbackKey}": { "address": {stringCallbackAddress}} +// +// ADR-8 middleware should callback on the returned address if it is a PacketActor +// (i.e. smart contract that accepts IBC callbacks). +func getCallbackAddress(callbackData map[string]interface{}) string { + callbackAddress, ok := callbackData[CallbackAddressKey].(string) + if !ok { + return "" + } + + return callbackAddress +} + +// AllowRetry returns true if the callback execution gas limit is less than the commit gas limit. +func (c CallbackData) AllowRetry() bool { + return c.ExecutionGasLimit < c.CommitGasLimit +} diff --git a/modules/apps/callbacks/types/callbacks_test.go b/modules/apps/callbacks/types/callbacks_test.go new file mode 100644 index 00000000000..8dbddf2a228 --- /dev/null +++ b/modules/apps/callbacks/types/callbacks_test.go @@ -0,0 +1,582 @@ +package types_test + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/cometbft/cometbft/crypto/secp256k1" + + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + transfer "github.com/cosmos/ibc-go/v7/modules/apps/transfer" + transfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" + ibctesting "github.com/cosmos/ibc-go/v7/testing" + ibcmock "github.com/cosmos/ibc-go/v7/testing/mock" +) + +func (s *CallbacksTypesTestSuite) TestGetCallbackData() { + var ( + sender = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + receiver = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + packetDataUnmarshaler porttypes.PacketDataUnmarshaler + packetData []byte + remainingGas uint64 + callbackKey string + ) + + // max gas is 1_000_000 + testCases := []struct { + name string + malleate func() + expCallbackData types.CallbackData + expError error + }{ + { + "success: source callback", + func() { + remainingGas = 2_000_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: sender, + ExecutionGasLimit: 1_000_000, + CommitGasLimit: 1_000_000, + }, + nil, + }, + { + "success: destination callback", + func() { + callbackKey = types.DestinationCallbackKey + remainingGas = 2_000_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: "", + ExecutionGasLimit: 1_000_000, + CommitGasLimit: 1_000_000, + }, + nil, + }, + { + "success: destination callback with 0 user defined gas limit", + func() { + callbackKey = types.DestinationCallbackKey + remainingGas = 2_000_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit":"0"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: "", + ExecutionGasLimit: 1_000_000, + CommitGasLimit: 1_000_000, + }, + nil, + }, + { + "success: source callback with gas limit < remaining gas < max gas", + func() { + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "50000"}}`, sender), + } + packetData = expPacketData.GetBytes() + + remainingGas = 100_000 + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: sender, + ExecutionGasLimit: 50_000, + CommitGasLimit: 50_000, + }, + nil, + }, + { + "success: source callback with remaining gas < gas limit < max gas", + func() { + remainingGas = 100_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "200000"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: sender, + ExecutionGasLimit: 100_000, + CommitGasLimit: 200_000, + }, + nil, + }, + { + "success: source callback with remaining gas < max gas < gas limit", + func() { + remainingGas = 100_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "2000000"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: sender, + ExecutionGasLimit: 100_000, + CommitGasLimit: 1_000_000, + }, + nil, + }, + { + "success: destination callback with remaining gas < max gas < gas limit", + func() { + callbackKey = types.DestinationCallbackKey + remainingGas = 100_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit": "2000000"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: "", + ExecutionGasLimit: 100_000, + CommitGasLimit: 1_000_000, + }, + nil, + }, + { + "success: source callback with max gas < remaining gas < gas limit", + func() { + remainingGas = 2_000_000 + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "3000000"}}`, sender), + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{ + CallbackAddress: sender, + SenderAddress: sender, + ExecutionGasLimit: 1_000_000, + CommitGasLimit: 1_000_000, + }, + nil, + }, + { + "failure: invalid packet data", + func() { + packetData = []byte("invalid packet data") + }, + types.CallbackData{}, + types.ErrCannotUnmarshalPacketData, + }, + { + "failure: packet data does not implement PacketDataProvider", + func() { + packetData = ibcmock.MockPacketData + packetDataUnmarshaler = ibcmock.IBCModule{} + }, + types.CallbackData{}, + types.ErrNotPacketDataProvider, + }, + { + "failure: empty memo", + func() { + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: "", + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{}, + types.ErrCallbackKeyNotFound, + }, + { + "failure: empty address", + func() { + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"address": ""}}`, + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{}, + types.ErrCallbackAddressNotFound, + }, + { + "failure: space address", + func() { + expPacketData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"address": " "}}`, + } + packetData = expPacketData.GetBytes() + }, + types.CallbackData{}, + types.ErrCallbackAddressNotFound, + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + callbackKey = types.SourceCallbackKey + + packetDataUnmarshaler = transfer.IBCModule{} + + tc.malleate() + + testPacket := channeltypes.Packet{Data: packetData} + callbackData, err := types.GetCallbackData(packetDataUnmarshaler, testPacket, remainingGas, uint64(1_000_000), callbackKey) + + expPass := tc.expError == nil + if expPass { + s.Require().NoError(err, tc.name) + s.Require().Equal(tc.expCallbackData, callbackData, tc.name) + + expAllowRetry := tc.expCallbackData.ExecutionGasLimit < tc.expCallbackData.CommitGasLimit + s.Require().Equal(expAllowRetry, callbackData.AllowRetry(), tc.name) + } else { + s.Require().ErrorIs(err, tc.expError, tc.name) + } + }) + } +} + +func (s *CallbacksTypesTestSuite) TestGetSourceCallbackDataTransfer() { + sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + + packetData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender), + } + packetDataBytes := packetData.GetBytes() + + expCallbackData := types.CallbackData{ + CallbackAddress: sender, + SenderAddress: sender, + ExecutionGasLimit: 1_000_000, + CommitGasLimit: 1_000_000, + } + + packetUnmarshaler := transfer.IBCModule{} + + testPacket := channeltypes.Packet{Data: packetDataBytes} + callbackData, err := types.GetSourceCallbackData(packetUnmarshaler, testPacket, 2_000_000, 1_000_000) + s.Require().NoError(err) + s.Require().Equal(expCallbackData, callbackData) +} + +func (s *CallbacksTypesTestSuite) TestGetDestCallbackDataTransfer() { + sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + + packetData := transfertypes.FungibleTokenPacketData{ + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender), + } + packetDataBytes := packetData.GetBytes() + + expCallbackData := types.CallbackData{ + CallbackAddress: sender, + SenderAddress: "", + ExecutionGasLimit: 1_000_000, + CommitGasLimit: 1_000_000, + } + + packetUnmarshaler := transfer.IBCModule{} + + testPacket := channeltypes.Packet{Data: packetDataBytes} + callbackData, err := types.GetDestCallbackData(packetUnmarshaler, testPacket, 2_000_000, 1_000_000) + s.Require().NoError(err) + s.Require().Equal(expCallbackData, callbackData) +} + +func (s *CallbacksTypesTestSuite) TestGetCallbackAddress() { + denom := ibctesting.TestCoin.Denom + amount := ibctesting.TestCoin.Amount.String() + sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + + testCases := []struct { + name string + packetData transfertypes.FungibleTokenPacketData + expAddress string + }{ + { + "success: memo has callbacks in json struct and properly formatted src_callback_address which does not match packet sender", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, receiver), + }, + receiver, + }, + { + "success: valid src_callback address specified in memo that matches sender", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender), + }, + sender, + }, + { + "failure: memo is empty", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: "", + }, + "", + }, + { + "failure: memo is not json string", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: "memo", + }, + "", + }, + { + "failure: memo has empty src_callback object", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {}}`, + }, + "", + }, + { + "failure: memo does not have callbacks in json struct", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"Key": 10}`, + }, + "", + }, + { + "failure: memo has src_callback in json struct but does not have address key", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"Key": 10}}`, + }, + "", + }, + { + "failure: memo has src_callback in json struct but does not have string value for address key", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"address": 10}}`, + }, + "", + }, + } + + for _, tc := range testCases { + tc := tc + s.Run(tc.name, func() { + callbackData, ok := tc.packetData.GetCustomPacketData(types.SourceCallbackKey).(map[string]interface{}) + s.Require().Equal(ok, callbackData != nil) + s.Require().Equal(tc.expAddress, types.GetCallbackAddress(callbackData), tc.name) + }) + } +} + +func (s *CallbacksTypesTestSuite) TestUserDefinedGasLimit() { + denom := ibctesting.TestCoin.Denom + amount := ibctesting.TestCoin.Amount.String() + sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + + testCases := []struct { + name string + packetData transfertypes.FungibleTokenPacketData + expUserGas uint64 + }{ + { + "success: memo is empty", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: "", + }, + 0, + }, + { + "success: memo has user defined gas limit", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"gas_limit": "100"}}`, + }, + 100, + }, + { + "success: user defined gas limit is zero", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"gas_limit": "0"}}`, + }, + 0, + }, + { + "failure: memo has empty src_callback object", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {}}`, + }, + 0, + }, + { + "failure: memo has user defined gas limit as json number", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"gas_limit": 100}}`, + }, + 0, + }, + { + "failure: memo has user defined gas limit as negative", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"gas_limit": "-100"}}`, + }, + 0, + }, + { + "failure: memo has user defined gas limit as string", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"gas_limit": "invalid"}}`, + }, + 0, + }, + { + "failure: memo has user defined gas limit as empty string", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `{"src_callback": {"gas_limit": ""}}`, + }, + 0, + }, + { + "failure: malformed memo", + transfertypes.FungibleTokenPacketData{ + Denom: denom, + Amount: amount, + Sender: sender, + Receiver: receiver, + Memo: `invalid`, + }, + 0, + }, + } + + for _, tc := range testCases { + callbackData, ok := tc.packetData.GetCustomPacketData(types.SourceCallbackKey).(map[string]interface{}) + s.Require().Equal(ok, callbackData != nil) + s.Require().Equal(tc.expUserGas, types.GetUserDefinedGasLimit(callbackData), tc.name) + } +} diff --git a/modules/apps/callbacks/types/errors.go b/modules/apps/callbacks/types/errors.go new file mode 100644 index 00000000000..b1b37209625 --- /dev/null +++ b/modules/apps/callbacks/types/errors.go @@ -0,0 +1,12 @@ +package types + +import ( + errorsmod "cosmossdk.io/errors" +) + +var ( + ErrCannotUnmarshalPacketData = errorsmod.Register(ModuleName, 2, "cannot unmarshal packet data") + ErrNotPacketDataProvider = errorsmod.Register(ModuleName, 3, "packet is not a PacketDataProvider") + ErrCallbackKeyNotFound = errorsmod.Register(ModuleName, 4, "callback key not found in packet data") + ErrCallbackAddressNotFound = errorsmod.Register(ModuleName, 5, "callback address not found in packet data") +) diff --git a/modules/apps/callbacks/types/events.go b/modules/apps/callbacks/types/events.go new file mode 100644 index 00000000000..ecf255047d1 --- /dev/null +++ b/modules/apps/callbacks/types/events.go @@ -0,0 +1,102 @@ +package types + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" +) + +const ( + // EventTypeSourceCallback is the event type for a source callback + EventTypeSourceCallback = "ibc_src_callback" + // EventTypeDestinationCallback is the event type for a destination callback + EventTypeDestinationCallback = "ibc_dest_callback" + + // AttributeKeyCallbackType denotes the condition that the callback is executed on: + // "acknowledgement": the callback is executed on the acknowledgement of the packet + // "timeout": the callback is executed on the timeout of the packet + // "recv_packet": the callback is executed on the reception of the packet + AttributeKeyCallbackType = "callback_type" + // AttributeKeyCallbackAddress denotes the callback address + AttributeKeyCallbackAddress = "callback_address" + // AttributeKeyCallbackResult denotes the callback result: + // AttributeValueCallbackSuccess: the callback is successfully executed + // AttributeValueCallbackFailure: the callback has failed to execute + AttributeKeyCallbackResult = "callback_result" + // AttributeKeyCallbackError denotes the callback error message + // if no error is returned, then this key will not be included in the event + AttributeKeyCallbackError = "callback_error" + // AttributeKeyCallbackGasLimit denotes the custom gas limit for the callback execution + // if custom gas limit is not in effect, then this key will not be included in the event + AttributeKeyCallbackGasLimit = "callback_exec_gas_limit" + // AttributeKeyCallbackCommitGasLimit denotes the gas needed to commit the callback even + // if the callback execution fails due to out of gas. + AttributeKeyCallbackCommitGasLimit = "callback_commit_gas_limit" + // AttributeKeyCallbackSourcePortID denotes the source port ID of the packet + AttributeKeyCallbackSourcePortID = "packet_src_port" + // AttributeKeyCallbackSourceChannelID denotes the source channel ID of the packet + AttributeKeyCallbackSourceChannelID = "packet_src_channel" + // AttributeKeyCallbackDestPortID denotes the destination port ID of the packet + AttributeKeyCallbackDestPortID = "packet_dest_port" + // AttributeKeyCallbackDestChannelID denotes the destination channel ID of the packet + AttributeKeyCallbackDestChannelID = "packet_dest_channel" + // AttributeKeyCallbackSequence denotes the sequence of the packet + AttributeKeyCallbackSequence = "packet_sequence" + + // AttributeValueCallbackSuccess denotes that the callback is successfully executed + AttributeValueCallbackSuccess = "success" + // AttributeValueCallbackFailure denotes that the callback has failed to execute + AttributeValueCallbackFailure = "failure" +) + +// EmitCallbackEvent emits an event for a callback +func EmitCallbackEvent( + ctx sdk.Context, + packet ibcexported.PacketI, + callbackType CallbackType, + callbackData CallbackData, + err error, +) { + attributes := []sdk.Attribute{ + sdk.NewAttribute(sdk.AttributeKeyModule, ModuleName), + sdk.NewAttribute(AttributeKeyCallbackType, string(callbackType)), + sdk.NewAttribute(AttributeKeyCallbackAddress, callbackData.CallbackAddress), + sdk.NewAttribute(AttributeKeyCallbackGasLimit, fmt.Sprintf("%d", callbackData.ExecutionGasLimit)), + sdk.NewAttribute(AttributeKeyCallbackCommitGasLimit, fmt.Sprintf("%d", callbackData.CommitGasLimit)), + sdk.NewAttribute(AttributeKeyCallbackSequence, fmt.Sprintf("%d", packet.GetSequence())), + } + if err == nil { + attributes = append(attributes, sdk.NewAttribute(AttributeKeyCallbackResult, AttributeValueCallbackSuccess)) + } else { + attributes = append( + attributes, + sdk.NewAttribute(AttributeKeyCallbackError, err.Error()), + sdk.NewAttribute(AttributeKeyCallbackResult, AttributeValueCallbackFailure), + ) + } + + var eventType string + switch callbackType { + case CallbackTypeReceivePacket: + eventType = EventTypeDestinationCallback + attributes = append( + attributes, sdk.NewAttribute(AttributeKeyCallbackDestPortID, packet.GetDestPort()), + sdk.NewAttribute(AttributeKeyCallbackDestChannelID, packet.GetDestChannel()), + ) + default: + eventType = EventTypeSourceCallback + attributes = append( + attributes, sdk.NewAttribute(AttributeKeyCallbackSourcePortID, packet.GetSourcePort()), + sdk.NewAttribute(AttributeKeyCallbackSourceChannelID, packet.GetSourceChannel()), + ) + } + + ctx.EventManager().EmitEvent( + sdk.NewEvent( + eventType, + attributes..., + ), + ) +} diff --git a/modules/apps/callbacks/types/events_test.go b/modules/apps/callbacks/types/events_test.go new file mode 100644 index 00000000000..615f936ca5f --- /dev/null +++ b/modules/apps/callbacks/types/events_test.go @@ -0,0 +1,193 @@ +package types_test + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + ibctesting "github.com/cosmos/ibc-go/v7/testing" +) + +func (s *CallbacksTypesTestSuite) TestEvents() { + testCases := []struct { + name string + packet channeltypes.Packet + callbackType types.CallbackType + callbackData types.CallbackData + callbackError error + expEvents ibctesting.EventsMap + }{ + { + "success: ack callback", + channeltypes.NewPacket( + ibctesting.MockPacketData, 1, ibctesting.MockPort, ibctesting.FirstChannelID, + ibctesting.MockFeePort, ibctesting.InvalidID, clienttypes.NewHeight(1, 100), 0, + ), + types.CallbackTypeAcknowledgementPacket, + types.CallbackData{ + CallbackAddress: ibctesting.TestAccAddress, + ExecutionGasLimit: 100000, + CommitGasLimit: 200000, + }, + nil, + ibctesting.EventsMap{ + types.EventTypeSourceCallback: { + sdk.AttributeKeyModule: types.ModuleName, + types.AttributeKeyCallbackType: string(types.CallbackTypeAcknowledgementPacket), + types.AttributeKeyCallbackAddress: ibctesting.TestAccAddress, + types.AttributeKeyCallbackGasLimit: "100000", + types.AttributeKeyCallbackCommitGasLimit: "200000", + types.AttributeKeyCallbackSourcePortID: ibctesting.MockPort, + types.AttributeKeyCallbackSourceChannelID: ibctesting.FirstChannelID, + types.AttributeKeyCallbackSequence: "1", + types.AttributeKeyCallbackResult: types.AttributeValueCallbackSuccess, + }, + }, + }, + { + "success: send packet callback", + channeltypes.NewPacket( + ibctesting.MockPacketData, 1, ibctesting.MockPort, ibctesting.FirstChannelID, + ibctesting.MockFeePort, ibctesting.InvalidID, clienttypes.NewHeight(1, 100), 0, + ), + types.CallbackTypeSendPacket, + types.CallbackData{ + CallbackAddress: ibctesting.TestAccAddress, + ExecutionGasLimit: 100000, + CommitGasLimit: 200000, + }, + nil, + ibctesting.EventsMap{ + types.EventTypeSourceCallback: { + sdk.AttributeKeyModule: types.ModuleName, + types.AttributeKeyCallbackType: string(types.CallbackTypeSendPacket), + types.AttributeKeyCallbackAddress: ibctesting.TestAccAddress, + types.AttributeKeyCallbackGasLimit: "100000", + types.AttributeKeyCallbackCommitGasLimit: "200000", + types.AttributeKeyCallbackSourcePortID: ibctesting.MockPort, + types.AttributeKeyCallbackSourceChannelID: ibctesting.FirstChannelID, + types.AttributeKeyCallbackSequence: "1", + types.AttributeKeyCallbackResult: types.AttributeValueCallbackSuccess, + }, + }, + }, + { + "success: timeout callback", + channeltypes.NewPacket( + ibctesting.MockPacketData, 1, ibctesting.MockPort, ibctesting.FirstChannelID, + ibctesting.MockFeePort, ibctesting.InvalidID, clienttypes.NewHeight(1, 100), 0, + ), + types.CallbackTypeTimeoutPacket, + types.CallbackData{ + CallbackAddress: ibctesting.TestAccAddress, + ExecutionGasLimit: 100000, + CommitGasLimit: 200000, + }, + nil, + ibctesting.EventsMap{ + types.EventTypeSourceCallback: { + sdk.AttributeKeyModule: types.ModuleName, + types.AttributeKeyCallbackType: string(types.CallbackTypeTimeoutPacket), + types.AttributeKeyCallbackAddress: ibctesting.TestAccAddress, + types.AttributeKeyCallbackGasLimit: "100000", + types.AttributeKeyCallbackCommitGasLimit: "200000", + types.AttributeKeyCallbackSourcePortID: ibctesting.MockPort, + types.AttributeKeyCallbackSourceChannelID: ibctesting.FirstChannelID, + types.AttributeKeyCallbackSequence: "1", + types.AttributeKeyCallbackResult: types.AttributeValueCallbackSuccess, + }, + }, + }, + { + "success: receive packet callback", + channeltypes.NewPacket( + ibctesting.MockPacketData, 1, ibctesting.MockPort, ibctesting.FirstChannelID, + ibctesting.MockFeePort, ibctesting.InvalidID, clienttypes.NewHeight(1, 100), 0, + ), + types.CallbackTypeReceivePacket, + types.CallbackData{ + CallbackAddress: ibctesting.TestAccAddress, + ExecutionGasLimit: 100000, + CommitGasLimit: 200000, + }, + nil, + ibctesting.EventsMap{ + types.EventTypeDestinationCallback: { + sdk.AttributeKeyModule: types.ModuleName, + types.AttributeKeyCallbackType: string(types.CallbackTypeReceivePacket), + types.AttributeKeyCallbackAddress: ibctesting.TestAccAddress, + types.AttributeKeyCallbackGasLimit: "100000", + types.AttributeKeyCallbackCommitGasLimit: "200000", + types.AttributeKeyCallbackDestPortID: ibctesting.MockFeePort, + types.AttributeKeyCallbackDestChannelID: ibctesting.InvalidID, + types.AttributeKeyCallbackSequence: "1", + types.AttributeKeyCallbackResult: types.AttributeValueCallbackSuccess, + }, + }, + }, + { + "success: unknown callback", + channeltypes.NewPacket( + ibctesting.MockPacketData, 1, ibctesting.MockPort, ibctesting.FirstChannelID, + ibctesting.MockFeePort, ibctesting.InvalidID, clienttypes.NewHeight(1, 100), 0, + ), + "something", + types.CallbackData{ + CallbackAddress: ibctesting.TestAccAddress, + ExecutionGasLimit: 100000, + CommitGasLimit: 200000, + }, + nil, + ibctesting.EventsMap{ + types.EventTypeSourceCallback: { + sdk.AttributeKeyModule: types.ModuleName, + types.AttributeKeyCallbackType: "something", + types.AttributeKeyCallbackAddress: ibctesting.TestAccAddress, + types.AttributeKeyCallbackGasLimit: "100000", + types.AttributeKeyCallbackCommitGasLimit: "200000", + types.AttributeKeyCallbackSourcePortID: ibctesting.MockPort, + types.AttributeKeyCallbackSourceChannelID: ibctesting.FirstChannelID, + types.AttributeKeyCallbackSequence: "1", + types.AttributeKeyCallbackResult: types.AttributeValueCallbackSuccess, + }, + }, + }, + { + "failure: ack callback with error", + channeltypes.NewPacket( + ibctesting.MockPacketData, 1, ibctesting.MockPort, ibctesting.FirstChannelID, + ibctesting.MockFeePort, ibctesting.InvalidID, clienttypes.NewHeight(1, 100), 0, + ), + types.CallbackTypeAcknowledgementPacket, + types.CallbackData{ + CallbackAddress: ibctesting.TestAccAddress, + ExecutionGasLimit: 100000, + CommitGasLimit: 200000, + }, + types.ErrNotPacketDataProvider, + ibctesting.EventsMap{ + types.EventTypeSourceCallback: { + sdk.AttributeKeyModule: types.ModuleName, + types.AttributeKeyCallbackType: string(types.CallbackTypeAcknowledgementPacket), + types.AttributeKeyCallbackAddress: ibctesting.TestAccAddress, + types.AttributeKeyCallbackGasLimit: "100000", + types.AttributeKeyCallbackCommitGasLimit: "200000", + types.AttributeKeyCallbackSourcePortID: ibctesting.MockPort, + types.AttributeKeyCallbackSourceChannelID: ibctesting.FirstChannelID, + types.AttributeKeyCallbackSequence: "1", + types.AttributeKeyCallbackResult: types.AttributeValueCallbackFailure, + types.AttributeKeyCallbackError: types.ErrNotPacketDataProvider.Error(), + }, + }, + }, + } + + for _, tc := range testCases { + newCtx := sdk.Context{}.WithEventManager(sdk.NewEventManager()) + + types.EmitCallbackEvent(newCtx, tc.packet, tc.callbackType, tc.callbackData, tc.callbackError) + events := newCtx.EventManager().Events().ToABCIEvents() + ibctesting.AssertEvents(&s.Suite, tc.expEvents, events) + } +} diff --git a/modules/apps/callbacks/types/expected_keepers.go b/modules/apps/callbacks/types/expected_keepers.go new file mode 100644 index 00000000000..f21d3dbd923 --- /dev/null +++ b/modules/apps/callbacks/types/expected_keepers.go @@ -0,0 +1,64 @@ +package types + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" +) + +// ContractKeeper defines the entry points exposed to the VM module which invokes a smart contract +type ContractKeeper interface { + // IBCSendPacketCallback is called in the source chain when a PacketSend is executed. The + // packetSenderAddress is determined by the underlying module, and may be empty if the sender is + // unknown or undefined. The contract is expected to handle the callback within the user defined + // gas limit, and handle any errors, or panics gracefully. + // If an error is returned, the transaction will be reverted by the callbacks middleware, and the + // packet will not be sent. + IBCSendPacketCallback( + ctx sdk.Context, + sourcePort string, + sourceChannel string, + timeoutHeight clienttypes.Height, + timeoutTimestamp uint64, + packetData []byte, + contractAddress, + packetSenderAddress string, + ) error + // IBCOnAcknowledgementPacketCallback is called in the source chain when a packet acknowledgement + // is received. The packetSenderAddress is determined by the underlying module, and may be empty if + // the sender is unknown or undefined. The contract is expected to handle the callback within the + // user defined gas limit, and handle any errors, or panics gracefully. + // If an error is returned, state will be reverted by the callbacks middleware. + IBCOnAcknowledgementPacketCallback( + ctx sdk.Context, + packet channeltypes.Packet, + acknowledgement []byte, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, + ) error + // IBCOnTimeoutPacketCallback is called in the source chain when a packet is not received before + // the timeout height. The packetSenderAddress is determined by the underlying module, and may be + // empty if the sender is unknown or undefined. The contract is expected to handle the callback + // within the user defined gas limit, and handle any error, out of gas, or panics gracefully. + // If an error is returned, state will be reverted by the callbacks middleware. + IBCOnTimeoutPacketCallback( + ctx sdk.Context, + packet channeltypes.Packet, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, + ) error + // IBCReceivePacketCallback is called in the destination chain when a packet acknowledgement is written. + // The contract is expected to handle the callback within the user defined gas limit, and handle any errors, + // out of gas, or panics gracefully. + // If an error is returned, state will be reverted by the callbacks middleware. + IBCReceivePacketCallback( + ctx sdk.Context, + packet ibcexported.PacketI, + ack ibcexported.Acknowledgement, + contractAddress string, + ) error +} diff --git a/modules/apps/callbacks/types/export_test.go b/modules/apps/callbacks/types/export_test.go new file mode 100644 index 00000000000..facb96952f6 --- /dev/null +++ b/modules/apps/callbacks/types/export_test.go @@ -0,0 +1,29 @@ +package types + +import ( + porttypes "github.com/cosmos/ibc-go/v7/modules/core/05-port/types" + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" +) + +/* + This file is to allow for unexported functions to be accessible to the testing package. +*/ + +// GetCallbackData is a wrapper around getCallbackData to allow the function to be directly called in tests. +func GetCallbackData( + packetDataUnmarshaler porttypes.PacketDataUnmarshaler, + packet ibcexported.PacketI, remainingGas uint64, + maxGas uint64, callbackKey string, +) (CallbackData, error) { + return getCallbackData(packetDataUnmarshaler, packet, remainingGas, maxGas, callbackKey) +} + +// GetCallbackAddress is a wrapper around getCallbackAddress to allow the function to be directly called in tests. +func GetCallbackAddress(callbackData map[string]interface{}) string { + return getCallbackAddress(callbackData) +} + +// GetUserDefinedGasLimit is a wrapper around getUserDefinedGasLimit to allow the function to be directly called in tests. +func GetUserDefinedGasLimit(callbackData map[string]interface{}) uint64 { + return getUserDefinedGasLimit(callbackData) +} diff --git a/modules/apps/callbacks/types/keys.go b/modules/apps/callbacks/types/keys.go new file mode 100644 index 00000000000..d07613cc466 --- /dev/null +++ b/modules/apps/callbacks/types/keys.go @@ -0,0 +1,31 @@ +package types + +type CallbackType string + +const ( + ModuleName = "ibccallbacks" + + CallbackTypeSendPacket CallbackType = "send_packet" + CallbackTypeAcknowledgementPacket CallbackType = "acknowledgement_packet" + CallbackTypeTimeoutPacket CallbackType = "timeout_packet" + CallbackTypeReceivePacket CallbackType = "receive_packet" + + // Source callback packet data is set inside the underlying packet data using the this key. + // ICS20 and ICS27 will store the callback packet data in the memo field as a json object. + // The expected format is as follows: + // {"src_callback": { ... }} + SourceCallbackKey = "src_callback" + // Destination callback packet data is set inside the underlying packet data using the this key. + // ICS20 and ICS27 will store the callback packet data in the memo field as a json object. + // The expected format is as follows: + // {"dest_callback": { ... }} + DestinationCallbackKey = "dest_callback" + // Callbacks' packet data is expected to contain the callback address under this key. + // The expected format for ICS20 and ICS27 memo field is as follows: + // { "{callbackKey}": { "address": {stringCallbackAddress}} + CallbackAddressKey = "address" + // Callbacks' packet data is expected to specify the user defined gas limit under this key. + // The expected format for ICS20 and ICS27 memo field is as follows: + // { "{callbackKey}": { ... , "gas_limit": {stringForCallback} } + UserDefinedGasLimitKey = "gas_limit" +) diff --git a/modules/apps/callbacks/types/types_test.go b/modules/apps/callbacks/types/types_test.go new file mode 100644 index 00000000000..ef21db59989 --- /dev/null +++ b/modules/apps/callbacks/types/types_test.go @@ -0,0 +1,28 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + ibctesting "github.com/cosmos/ibc-go/v7/testing" +) + +// CallbacksTestSuite defines the needed instances and methods to test callbacks +type CallbacksTypesTestSuite struct { + suite.Suite + + coord *ibctesting.Coordinator + + chain *ibctesting.TestChain +} + +// SetupTest creates a coordinator with 1 test chain. +func (s *CallbacksTypesTestSuite) SetupSuite() { + s.coord = ibctesting.NewCoordinator(s.T(), 1) + s.chain = s.coord.GetChain(ibctesting.GetChainID(1)) +} + +func TestCallbacksTypesTestSuite(t *testing.T) { + suite.Run(t, new(CallbacksTypesTestSuite)) +} diff --git a/modules/apps/transfer/keeper/keeper_test.go b/modules/apps/transfer/keeper/keeper_test.go index 739c18fdf5e..04ea5100ada 100644 --- a/modules/apps/transfer/keeper/keeper_test.go +++ b/modules/apps/transfer/keeper/keeper_test.go @@ -13,7 +13,6 @@ import ( storetypes "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" - ibcfeekeeper "github.com/cosmos/ibc-go/v7/modules/apps/29-fee/keeper" "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" channelkeeper "github.com/cosmos/ibc-go/v7/modules/core/04-channel/keeper" ibctesting "github.com/cosmos/ibc-go/v7/testing" @@ -269,11 +268,9 @@ func (suite *KeeperTestSuite) TestUnsetParams() { func (suite *KeeperTestSuite) TestWithICS4Wrapper() { suite.SetupTest() - // test if the ics4 wrapper is the fee keeper initially + // test if the ics4 wrapper is the channel keeper initially ics4Wrapper := suite.chainA.GetSimApp().TransferKeeper.GetICS4Wrapper() - _, isFeeKeeper := ics4Wrapper.(ibcfeekeeper.Keeper) - suite.Require().True(isFeeKeeper) _, isChannelKeeper := ics4Wrapper.(channelkeeper.Keeper) suite.Require().False(isChannelKeeper) @@ -283,6 +280,4 @@ func (suite *KeeperTestSuite) TestWithICS4Wrapper() { _, isChannelKeeper = ics4Wrapper.(channelkeeper.Keeper) suite.Require().True(isChannelKeeper) - _, isFeeKeeper = ics4Wrapper.(ibcfeekeeper.Keeper) - suite.Require().False(isFeeKeeper) } diff --git a/testing/mock/contract_keeper.go b/testing/mock/contract_keeper.go new file mode 100644 index 00000000000..d030de2393b --- /dev/null +++ b/testing/mock/contract_keeper.go @@ -0,0 +1,154 @@ +package mock + +import ( + "fmt" + + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + + callbacktypes "github.com/cosmos/ibc-go/v7/modules/apps/callbacks/types" + clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" + ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported" +) + +// MockKeeper implements callbacktypes.ContractKeeper +var _ callbacktypes.ContractKeeper = (*ContractKeeper)(nil) + +// This is a mock contract keeper used for testing. It is not wired up to any modules. +// It implements the interface functions expected by the ibccallbacks middleware +// so that it can be tested with simapp. The keeper is responsible for tracking +// two metrics: +// - number of callbacks called per callback type +// - stateful entry attempts +// +// The counter for callbacks allows us to ensure the correct callbacks were routed to +// and the stateful entries allows us to track state reversals or reverted state upon +// contract execution failure or out of gas errors. +type ContractKeeper struct { + key storetypes.StoreKey + + Counters map[callbacktypes.CallbackType]int +} + +// SetStateEntryCounter sets state entry counter. The number of stateful +// entries is tracked as a uint8. This function is used to test state reversals. +func (k ContractKeeper) SetStateEntryCounter(ctx sdk.Context, count uint8) { + store := ctx.KVStore(k.key) + store.Set([]byte(StatefulCounterKey), []byte{count}) +} + +// GetStateEntryCounter returns the state entry counter stored in state. +func (k ContractKeeper) GetStateEntryCounter(ctx sdk.Context) uint8 { + store := ctx.KVStore(k.key) + bz := store.Get([]byte(StatefulCounterKey)) + if bz == nil { + return 0 + } + return bz[0] +} + +// IncrementStatefulCounter increments the stateful callback counter in state. +func (k ContractKeeper) IncrementStateEntryCounter(ctx sdk.Context) { + count := k.GetStateEntryCounter(ctx) + k.SetStateEntryCounter(ctx, count+1) +} + +// NewKeeper creates a new mock ContractKeeper. +func NewContractKeeper(key storetypes.StoreKey) ContractKeeper { + return ContractKeeper{ + key: key, + Counters: make(map[callbacktypes.CallbackType]int), + } +} + +// IBCPacketSendCallback returns nil if the gas meter has greater than +// or equal to 500_000 gas remaining. +// This function oog panics if the gas remaining is less than 500_000. +// This function errors if the authAddress is MockCallbackUnauthorizedAddress. +func (k ContractKeeper) IBCSendPacketCallback( + ctx sdk.Context, + sourcePort string, + sourceChannel string, + timeoutHeight clienttypes.Height, + timeoutTimestamp uint64, + packetData []byte, + contractAddress, + packetSenderAddress string, +) error { + return k.processMockCallback(ctx, callbacktypes.CallbackTypeSendPacket, packetSenderAddress) +} + +// IBCOnAcknowledgementPacketCallback returns nil if the gas meter has greater than +// or equal to 500_000 gas remaining. +// This function oog panics if the gas remaining is less than 500_000. +// This function errors if the authAddress is MockCallbackUnauthorizedAddress. +func (k ContractKeeper) IBCOnAcknowledgementPacketCallback( + ctx sdk.Context, + packet channeltypes.Packet, + acknowledgement []byte, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, +) error { + return k.processMockCallback(ctx, callbacktypes.CallbackTypeAcknowledgementPacket, packetSenderAddress) +} + +// IBCOnTimeoutPacketCallback returns nil if the gas meter has greater than +// or equal to 500_000 gas remaining. +// This function oog panics if the gas remaining is less than 500_000. +// This function errors if the authAddress is MockCallbackUnauthorizedAddress. +func (k ContractKeeper) IBCOnTimeoutPacketCallback( + ctx sdk.Context, + packet channeltypes.Packet, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, +) error { + return k.processMockCallback(ctx, callbacktypes.CallbackTypeTimeoutPacket, packetSenderAddress) +} + +// IBCReceivePacketCallback returns nil if the gas meter has greater than +// or equal to 500_000 gas remaining. +// This function oog panics if the gas remaining is less than 500_000. +// This function errors if the authAddress is MockCallbackUnauthorizedAddress. +func (k ContractKeeper) IBCReceivePacketCallback( + ctx sdk.Context, + packet ibcexported.PacketI, + ack ibcexported.Acknowledgement, + contractAddress string, +) error { + return k.processMockCallback(ctx, callbacktypes.CallbackTypeReceivePacket, "") +} + +// processMockCallback returns nil if the gas meter has greater than or equal to 500_000 gas remaining. +// This function oog panics if the gas remaining is less than 500_000. +// This function errors if the authAddress is MockCallbackUnauthorizedAddress. +func (k ContractKeeper) processMockCallback( + ctx sdk.Context, + callbackType callbacktypes.CallbackType, + authAddress string, +) error { + gasRemaining := ctx.GasMeter().GasRemaining() + + // increment stateful entries, if the callbacks module handler + // reverts state, we can check by querying for the counter + // currently stored. + k.IncrementStateEntryCounter(ctx) + + // increment callback execution attempts + k.Counters[callbackType]++ + + if gasRemaining < 500000 { + // consume gas will panic since we attempt to consume 500_000 gas, for tests + ctx.GasMeter().ConsumeGas(500000, fmt.Sprintf("mock %s callback panic", callbackType)) + } + + if authAddress == MockCallbackUnauthorizedAddress { + ctx.GasMeter().ConsumeGas(500000, fmt.Sprintf("mock %s callback unauthorized", callbackType)) + return MockApplicationCallbackError + } + + ctx.GasMeter().ConsumeGas(500000, fmt.Sprintf("mock %s callback success", callbackType)) + return nil +} diff --git a/testing/mock/mock.go b/testing/mock/mock.go index 26c26247b8f..9913e6ead70 100644 --- a/testing/mock/mock.go +++ b/testing/mock/mock.go @@ -33,6 +33,8 @@ const ( ) var ( + StatefulCounterKey = "stateful-callback-counter" + MockAcknowledgement = channeltypes.NewResultAcknowledgement([]byte("mock acknowledgement")) MockFailAcknowledgement = channeltypes.NewErrorAcknowledgement(fmt.Errorf("mock failed acknowledgement")) MockPacketData = []byte("mock packet data") @@ -41,12 +43,13 @@ var ( MockRecvCanaryCapabilityName = "mock receive canary capability name" MockAckCanaryCapabilityName = "mock acknowledgement canary capability name" MockTimeoutCanaryCapabilityName = "mock timeout canary capability name" + MockCallbackUnauthorizedAddress = "cosmos15ulrf36d4wdtrtqzkgaan9ylwuhs7k7qz753uk" // MockApplicationCallbackError should be returned when an application callback should fail. It is possible to // test that this error was returned using ErrorIs. MockApplicationCallbackError error = &applicationCallbackError{} ) -var _ porttypes.IBCModule = IBCModule{} +var _ porttypes.IBCModule = (*IBCModule)(nil) // Expected Interface // PortKeeper defines the expected IBC port keeper diff --git a/testing/simapp/app.go b/testing/simapp/app.go index b267158af08..9567d04018b 100644 --- a/testing/simapp/app.go +++ b/testing/simapp/app.go @@ -112,6 +112,7 @@ import ( ibcfee "github.com/cosmos/ibc-go/v7/modules/apps/29-fee" ibcfeekeeper "github.com/cosmos/ibc-go/v7/modules/apps/29-fee/keeper" ibcfeetypes "github.com/cosmos/ibc-go/v7/modules/apps/29-fee/types" + ibccallbacks "github.com/cosmos/ibc-go/v7/modules/apps/callbacks" transfer "github.com/cosmos/ibc-go/v7/modules/apps/transfer" ibctransferkeeper "github.com/cosmos/ibc-go/v7/modules/apps/transfer/keeper" ibctransfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" @@ -247,6 +248,9 @@ type SimApp struct { ScopedIBCMockKeeper capabilitykeeper.ScopedKeeper ScopedICAMockKeeper capabilitykeeper.ScopedKeeper + // mock contract keeper used for testing + MockContractKeeper ibcmock.ContractKeeper + // make IBC modules public for test purposes // these modules are never directly routed to by the IBC Router ICAAuthModule ibcmock.IBCModule @@ -435,6 +439,10 @@ func NewSimApp( appCodec, keys[ibcexported.StoreKey], app.GetSubspace(ibcexported.ModuleName), app.StakingKeeper, app.UpgradeKeeper, scopedIBCKeeper, authtypes.NewModuleAddress(govtypes.ModuleName).String(), ) + // NOTE: The mock ContractKeeper is only created for testing. + // Real applications should not use the mock ContractKeeper + app.MockContractKeeper = ibcmock.NewContractKeeper(memKeys[ibcmock.MemStoreKey]) + // register the proposal types govRouter := govv1beta1.NewRouter() govRouter.AddRoute(govtypes.RouterKey, govv1beta1.ProposalHandler). @@ -491,9 +499,11 @@ func NewSimApp( ibcRouter := porttypes.NewRouter() // Middleware Stacks + maxCallbackGas := uint64(1_000_000) // Create Transfer Keeper and pass IBCFeeKeeper as expected Channel and PortKeeper // since fee middleware will wrap the IBCKeeper for underlying application. + // NOTE: the Transfer Keeper's ICS4Wrapper can later be replaced. app.TransferKeeper = ibctransferkeeper.NewKeeper( appCodec, keys[ibctransfertypes.StoreKey], app.GetSubspace(ibctransfertypes.ModuleName), app.IBCFeeKeeper, // ISC4 Wrapper: fee IBC middleware @@ -515,12 +525,13 @@ func NewSimApp( // Create Transfer Stack // SendPacket, since it is originating from the application to core IBC: - // transferKeeper.SendPacket -> fee.SendPacket -> channel.SendPacket + // transferKeeper.SendPacket -> fee.SendPacket -> callbacks.SendPacket -> channel.SendPacket // RecvPacket, message that originates from core IBC and goes down to app, the flow is the other way - // channel.RecvPacket -> fee.OnRecvPacket -> transfer.OnRecvPacket + // channel.RecvPacket -> callbacks.OnRecvPacket -> fee.OnRecvPacket -> transfer.OnRecvPacket // transfer stack contains (from top to bottom): + // - IBC Callbacks Middleware // - IBC Fee Middleware // - Transfer @@ -528,13 +539,16 @@ func NewSimApp( var transferStack porttypes.IBCModule transferStack = transfer.NewIBCModule(app.TransferKeeper) transferStack = ibcfee.NewIBCMiddleware(transferStack, app.IBCFeeKeeper) + transferStack = ibccallbacks.NewIBCMiddleware(transferStack, app.IBCFeeKeeper, app.MockContractKeeper, maxCallbackGas) + // Since the callbacks middleware itself is an ics4wrapper, it needs to be passed to the transfer keeper + app.TransferKeeper.WithICS4Wrapper(transferStack.(porttypes.Middleware)) // Add transfer stack to IBC Router ibcRouter.AddRoute(ibctransfertypes.ModuleName, transferStack) // Create Interchain Accounts Stack // SendPacket, since it is originating from the application to core IBC: - // icaAuthModuleKeeper.SendTx -> icaController.SendPacket -> fee.SendPacket -> channel.SendPacket + // icaAuthModuleKeeper.SendTx -> icaController.SendPacket -> fee.SendPacket -> callbacks.SendPacket -> channel.SendPacket // initialize ICA module with mock module as the authentication module on the controller side var icaControllerStack porttypes.IBCModule @@ -542,9 +556,12 @@ func NewSimApp( app.ICAAuthModule = icaControllerStack.(ibcmock.IBCModule) icaControllerStack = icacontroller.NewIBCMiddleware(icaControllerStack, app.ICAControllerKeeper) icaControllerStack = ibcfee.NewIBCMiddleware(icaControllerStack, app.IBCFeeKeeper) + icaControllerStack = ibccallbacks.NewIBCMiddleware(icaControllerStack, app.IBCFeeKeeper, app.MockContractKeeper, maxCallbackGas) + // Since the callbacks middleware itself is an ics4wrapper, it needs to be passed to the ica controller keeper + app.ICAControllerKeeper.WithICS4Wrapper(icaControllerStack.(porttypes.Middleware)) // RecvPacket, message that originates from core IBC and goes down to app, the flow is: - // channel.RecvPacket -> fee.OnRecvPacket -> icaHost.OnRecvPacket + // channel.RecvPacket -> callbacks.OnRecvPacket -> fee.OnRecvPacket -> icaHost.OnRecvPacket var icaHostStack porttypes.IBCModule icaHostStack = icahost.NewIBCModule(app.ICAHostKeeper) @@ -572,7 +589,8 @@ func NewSimApp( // create fee wrapped mock module feeMockModule := ibcmock.NewIBCModule(&mockModule, ibcmock.NewIBCApp(MockFeePort, scopedFeeMockKeeper)) app.FeeMockModule = feeMockModule - feeWithMockModule := ibcfee.NewIBCMiddleware(feeMockModule, app.IBCFeeKeeper) + var feeWithMockModule porttypes.Middleware = ibcfee.NewIBCMiddleware(feeMockModule, app.IBCFeeKeeper) + feeWithMockModule = ibccallbacks.NewIBCMiddleware(feeWithMockModule, app.IBCFeeKeeper, app.MockContractKeeper, maxCallbackGas) ibcRouter.AddRoute(MockFeePort, feeWithMockModule) // Seal the IBC Router