diff --git a/modules/core/04-channel/types/upgrade.go b/modules/core/04-channel/types/upgrade.go index 22bdb95e595..ea3e8508389 100644 --- a/modules/core/04-channel/types/upgrade.go +++ b/modules/core/04-channel/types/upgrade.go @@ -83,13 +83,21 @@ func (u *UpgradeError) Error() string { return u.err.Error() } -// Is returns true if the underlying error is of the given err type. -func (u *UpgradeError) Is(err error) bool { - return errors.Is(u.err, err) +// Is returns true if the of the provided error is an upgrade error. +func (*UpgradeError) Is(err error) bool { + _, ok := err.(*UpgradeError) + return ok } -// Unwrap returns the base error that caused the upgrade to fail. +// Unwrap returns the next error in the error chain. +// If there is no next error, Unwrap returns nil. func (u *UpgradeError) Unwrap() error { + return u.err +} + +// Cause implements the sdk error interface which uses this function to unwrap the error in various functions such as `wrappedError.Is()`. +// Cause returns the underlying error which caused the upgrade to fail. +func (u *UpgradeError) Cause() error { baseError := u.err for { if err := errors.Unwrap(baseError); err != nil { @@ -100,12 +108,6 @@ func (u *UpgradeError) Unwrap() error { } } -// Cause implements the sdk error interface which uses this function to unwrap the error in various functions such as `wrappedError.Is()`. -// Cause returns the underlying error which caused the upgrade to fail. -func (u *UpgradeError) Cause() error { - return u.err -} - // GetErrorReceipt returns an error receipt with the code from the underlying error type stripped. func (u *UpgradeError) GetErrorReceipt() ErrorReceipt { // restoreErrorString defines a string constant included in error receipts. @@ -122,14 +124,5 @@ func (u *UpgradeError) GetErrorReceipt() ErrorReceipt { // IsUpgradeError returns true if err is of type UpgradeError or contained // in the error chain of err and false otherwise. func IsUpgradeError(err error) bool { - for { - _, ok := err.(*UpgradeError) - if ok { - return true - } - - if err = errors.Unwrap(err); err == nil { - return false - } - } + return errors.Is(err, &UpgradeError{}) } diff --git a/modules/core/04-channel/types/upgrade_test.go b/modules/core/04-channel/types/upgrade_test.go index fdb8a90cec5..b1bb8315d46 100644 --- a/modules/core/04-channel/types/upgrade_test.go +++ b/modules/core/04-channel/types/upgrade_test.go @@ -159,18 +159,33 @@ func (suite *TypesTestSuite) TestGetErrorReceipt() { suite.Require().Equal(upgradeError2.GetErrorReceipt().Message, upgradeError.GetErrorReceipt().Message) } -// TestUpgradeErrorUnwrap tests that the underlying error is not modified when Unwrap is called. +// TestUpgradeErrorUnwrap tests that the underlying error is returned by Unwrap. func (suite *TypesTestSuite) TestUpgradeErrorUnwrap() { - baseUnderlyingError := errorsmod.Wrap(types.ErrInvalidChannel, "base error") - wrappedErr := errorsmod.Wrap(baseUnderlyingError, "wrapped error") - upgradeError := types.NewUpgradeError(1, wrappedErr) - - originalUpgradeError := upgradeError.Error() - unWrapped := errors.Unwrap(upgradeError) - postUnwrapUpgradeError := upgradeError.Error() + testCases := []struct { + msg string + upgradeError *types.UpgradeError + expError error + }{ + { + msg: "no underlying error", + upgradeError: types.NewUpgradeError(1, nil), + expError: nil, + }, + { + msg: "underlying error", + upgradeError: types.NewUpgradeError(1, types.ErrInvalidUpgrade), + expError: types.ErrInvalidUpgrade, + }, + } - suite.Require().Equal(types.ErrInvalidChannel, unWrapped, "unwrapped error was not equal to base underlying error") - suite.Require().Equal(originalUpgradeError, postUnwrapUpgradeError, "original error was modified when unwrapped") + for _, tc := range testCases { + tc := tc + suite.Run(tc.msg, func() { + upgradeError := tc.upgradeError + err := upgradeError.Unwrap() + suite.Require().Equal(tc.expError, err) + }) + } } func (suite *TypesTestSuite) TestIsUpgradeError() {