diff --git a/modules/core/04-channel/keeper/keeper.go b/modules/core/04-channel/keeper/keeper.go index a65d69266c1..5dce1aff54d 100644 --- a/modules/core/04-channel/keeper/keeper.go +++ b/modules/core/04-channel/keeper/keeper.go @@ -514,6 +514,12 @@ func (k Keeper) setUpgradeErrorReceipt(ctx sdk.Context, portID, channelID string store.Set(host.ChannelUpgradeErrorKey(portID, channelID), bz) } +// hasUpgrade returns true if a proposed upgrade exists in store +func (k Keeper) hasUpgrade(ctx sdk.Context, portID, channelID string) bool { + store := ctx.KVStore(k.storeKey) + return store.Has(host.ChannelUpgradeKey(portID, channelID)) +} + // GetUpgrade returns the proposed upgrade for the provided port and channel identifiers. func (k Keeper) GetUpgrade(ctx sdk.Context, portID, channelID string) (types.Upgrade, bool) { store := ctx.KVStore(k.storeKey) diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index e4c29a34560..b6fbcf61b01 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -53,6 +53,11 @@ func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID strin panic(fmt.Errorf("could not find existing channel when updating channel state in successful ChanUpgradeInit step, channelID: %s, portID: %s", channelID, portID)) } + if k.hasUpgrade(ctx, portID, channelID) { + // invalidating previous upgrade + k.WriteErrorReceipt(ctx, portID, channelID, types.NewUpgradeError(channel.UpgradeSequence, types.ErrInvalidUpgrade)) + } + channel.UpgradeSequence++ upgrade.Fields.Version = upgradeVersion diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index fc29328475f..ffb90ddfac2 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -671,6 +671,14 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { }, nil, }, + { + "failure if initializing chain reinitializes before ACK", + func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + }, + commitmenttypes.ErrInvalidProof, // sequences are out of sync + }, { "channel not found", func() { @@ -888,6 +896,76 @@ func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() { } } +func (suite *KeeperTestSuite) TestChanUpgrade_ReinitializedBeforeAck() { + var path *ibctesting.Path + suite.Run("setup path", func() { + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + }) + + suite.Run("chainA upgrade init", func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(uint64(1), channel.UpgradeSequence) + }) + + suite.Run("chainB upgrade try", func() { + err := path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + }) + + suite.Run("chainA upgrade init reinitialized after ack", func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(uint64(2), channel.UpgradeSequence) + }) + + suite.Run("chan upgrade ack fails", func() { + err := path.EndpointA.ChanUpgradeAck() + suite.Require().Error(err) + }) + + suite.Run("chainB upgrade cancel", func() { + err := path.EndpointB.ChanUpgradeCancel() + suite.Require().NoError(err) + }) + + suite.Run("upgrade handshake succeeds on new upgrade attempt", func() { + err := path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeConfirm() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeOpen() + suite.Require().NoError(err) + }) + + suite.Run("assert successful upgrade expected channel state", func() { + channelA := path.EndpointA.GetChannel() + suite.Require().Equal(types.OPEN, channelA.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelA.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, path.EndpointB.GetChannel().Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(2), channelA.UpgradeSequence, "upgrade sequence should be incremented") + + channelB := path.EndpointB.GetChannel() + suite.Require().Equal(types.OPEN, channelB.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(2), channelB.UpgradeSequence, "upgrade sequence should be incremented") + }) +} + func (suite *KeeperTestSuite) TestChanUpgradeConfirm() { var ( path *ibctesting.Path