Skip to content

Commit

Permalink
fix: Write ErrorReceipt for previous upgrade on Reinitialization (cos…
Browse files Browse the repository at this point in the history
…mos#5732)

* fix reinitialization

* write error receipt on reinitialization

* gofumpt

* switch to Has instead of Get

* imp: HasUpgrade -> hasUpgrade

---------

Co-authored-by: colin axnér <25233464+colin-axner@users.noreply.github.com>
  • Loading branch information
AdityaSripal and colin-axner committed Jan 25, 2024
1 parent bddbe49 commit 574a639
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
6 changes: 6 additions & 0 deletions modules/core/04-channel/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ func (k Keeper) setUpgradeErrorReceipt(ctx sdk.Context, portID, channelID string
store.Set(host.ChannelUpgradeErrorKey(portID, channelID), bz)
}

// hasUpgrade returns true if a proposed upgrade exists in store
func (k Keeper) hasUpgrade(ctx sdk.Context, portID, channelID string) bool {
store := ctx.KVStore(k.storeKey)
return store.Has(host.ChannelUpgradeKey(portID, channelID))
}

// GetUpgrade returns the proposed upgrade for the provided port and channel identifiers.
func (k Keeper) GetUpgrade(ctx sdk.Context, portID, channelID string) (types.Upgrade, bool) {
store := ctx.KVStore(k.storeKey)
Expand Down
5 changes: 5 additions & 0 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ func (k Keeper) WriteUpgradeInitChannel(ctx sdk.Context, portID, channelID strin
panic(fmt.Errorf("could not find existing channel when updating channel state in successful ChanUpgradeInit step, channelID: %s, portID: %s", channelID, portID))
}

if k.hasUpgrade(ctx, portID, channelID) {
// invalidating previous upgrade
k.WriteErrorReceipt(ctx, portID, channelID, types.NewUpgradeError(channel.UpgradeSequence, types.ErrInvalidUpgrade))
}

channel.UpgradeSequence++

upgrade.Fields.Version = upgradeVersion
Expand Down
78 changes: 78 additions & 0 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,14 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
},
nil,
},
{
"failure if initializing chain reinitializes before ACK",
func() {
err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)
},
commitmenttypes.ErrInvalidProof, // sequences are out of sync
},
{
"channel not found",
func() {
Expand Down Expand Up @@ -888,6 +896,76 @@ func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() {
}
}

func (suite *KeeperTestSuite) TestChanUpgrade_ReinitializedBeforeAck() {
var path *ibctesting.Path
suite.Run("setup path", func() {
path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
})

suite.Run("chainA upgrade init", func() {
err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

channel := path.EndpointA.GetChannel()
suite.Require().Equal(uint64(1), channel.UpgradeSequence)
})

suite.Run("chainB upgrade try", func() {
err := path.EndpointB.ChanUpgradeTry()
suite.Require().NoError(err)
})

suite.Run("chainA upgrade init reinitialized after ack", func() {
err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

channel := path.EndpointA.GetChannel()
suite.Require().Equal(uint64(2), channel.UpgradeSequence)
})

suite.Run("chan upgrade ack fails", func() {
err := path.EndpointA.ChanUpgradeAck()
suite.Require().Error(err)
})

suite.Run("chainB upgrade cancel", func() {
err := path.EndpointB.ChanUpgradeCancel()
suite.Require().NoError(err)
})

suite.Run("upgrade handshake succeeds on new upgrade attempt", func() {
err := path.EndpointB.ChanUpgradeTry()
suite.Require().NoError(err)

err = path.EndpointA.ChanUpgradeAck()
suite.Require().NoError(err)

err = path.EndpointB.ChanUpgradeConfirm()
suite.Require().NoError(err)

err = path.EndpointA.ChanUpgradeOpen()
suite.Require().NoError(err)
})

suite.Run("assert successful upgrade expected channel state", func() {
channelA := path.EndpointA.GetChannel()
suite.Require().Equal(types.OPEN, channelA.State, "channel should be in OPEN state")
suite.Require().Equal(mock.UpgradeVersion, channelA.Version, "version should be correctly upgraded")
suite.Require().Equal(mock.UpgradeVersion, path.EndpointB.GetChannel().Version, "version should be correctly upgraded")
suite.Require().Equal(uint64(2), channelA.UpgradeSequence, "upgrade sequence should be incremented")

channelB := path.EndpointB.GetChannel()
suite.Require().Equal(types.OPEN, channelB.State, "channel should be in OPEN state")
suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded")
suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded")
suite.Require().Equal(uint64(2), channelB.UpgradeSequence, "upgrade sequence should be incremented")
})
}

func (suite *KeeperTestSuite) TestChanUpgradeConfirm() {
var (
path *ibctesting.Path
Expand Down

0 comments on commit 574a639

Please sign in to comment.