Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support self calling contract on instantiation #300

Merged
merged 3 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
distr.NewAppModule(appCodec, app.distrKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
staking.NewAppModule(appCodec, app.stakingKeeper, app.accountKeeper, app.bankKeeper),
upgrade.NewAppModule(app.upgradeKeeper),
wasm.NewAppModule(app.wasmKeeper),
wasm.NewAppModule(&app.wasmKeeper),
evidence.NewAppModule(app.evidenceKeeper),
ibc.NewAppModule(app.ibcKeeper),
params.NewAppModule(app.paramsKeeper),
Expand Down Expand Up @@ -472,7 +472,7 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
distr.NewAppModule(appCodec, app.distrKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
slashing.NewAppModule(appCodec, app.slashingKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
params.NewAppModule(app.paramsKeeper),
wasm.NewAppModule(app.wasmKeeper),
wasm.NewAppModule(&app.wasmKeeper),
evidence.NewAppModule(app.evidenceKeeper),
ibc.NewAppModule(app.ibcKeeper),
transferModule,
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ func TestInitGenesis(t *testing.T) {
})

// export into genstate
genState := ExportGenesis(data.ctx, data.keeper)
genState := ExportGenesis(data.ctx, &data.keeper)

// create new app to import genstate into
newData := setupTest(t)
q2 := newData.module.LegacyQuerierHandler(nil)

// initialize new app with genstate
InitGenesis(newData.ctx, newData.keeper, *genState)
InitGenesis(newData.ctx, &newData.keeper, *genState)

// run same checks again on newdata, to make sure it was reinitialized correctly
assertCodeList(t, q2, newData.ctx, 1)
Expand Down
14 changes: 7 additions & 7 deletions x/wasm/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

// NewHandler returns a handler for "bank" type messages.
func NewHandler(k Keeper) sdk.Handler {
func NewHandler(k *Keeper) sdk.Handler {
return func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())

Expand Down Expand Up @@ -47,7 +47,7 @@ func filteredMessageEvents(manager *sdk.EventManager) []abci.Event {
return res
}

func handleStoreCode(ctx sdk.Context, k Keeper, msg *MsgStoreCode) (*sdk.Result, error) {
func handleStoreCode(ctx sdk.Context, k *Keeper, msg *MsgStoreCode) (*sdk.Result, error) {
err := msg.ValidateBasic()
if err != nil {
return nil, err
Expand All @@ -73,7 +73,7 @@ func handleStoreCode(ctx sdk.Context, k Keeper, msg *MsgStoreCode) (*sdk.Result,
}, nil
}

func handleInstantiate(ctx sdk.Context, k Keeper, msg *MsgInstantiateContract) (*sdk.Result, error) {
func handleInstantiate(ctx sdk.Context, k *Keeper, msg *MsgInstantiateContract) (*sdk.Result, error) {
contractAddr, err := k.Instantiate(ctx, msg.CodeID, msg.Sender, msg.Admin, msg.InitMsg, msg.Label, msg.InitFunds)
if err != nil {
return nil, err
Expand All @@ -95,7 +95,7 @@ func handleInstantiate(ctx sdk.Context, k Keeper, msg *MsgInstantiateContract) (
}, nil
}

func handleExecute(ctx sdk.Context, k Keeper, msg *MsgExecuteContract) (*sdk.Result, error) {
func handleExecute(ctx sdk.Context, k *Keeper, msg *MsgExecuteContract) (*sdk.Result, error) {
res, err := k.Execute(ctx, msg.Contract, msg.Sender, msg.Msg, msg.SentFunds)
if err != nil {
return nil, err
Expand All @@ -115,7 +115,7 @@ func handleExecute(ctx sdk.Context, k Keeper, msg *MsgExecuteContract) (*sdk.Res
return res, nil
}

func handleMigration(ctx sdk.Context, k Keeper, msg *MsgMigrateContract) (*sdk.Result, error) {
func handleMigration(ctx sdk.Context, k *Keeper, msg *MsgMigrateContract) (*sdk.Result, error) {
res, err := k.Migrate(ctx, msg.Contract, msg.Sender, msg.CodeID, msg.MigrateMsg)
if err != nil {
return nil, err
Expand All @@ -133,7 +133,7 @@ func handleMigration(ctx sdk.Context, k Keeper, msg *MsgMigrateContract) (*sdk.R
return res, nil
}

func handleUpdateContractAdmin(ctx sdk.Context, k Keeper, msg *MsgUpdateAdmin) (*sdk.Result, error) {
func handleUpdateContractAdmin(ctx sdk.Context, k *Keeper, msg *MsgUpdateAdmin) (*sdk.Result, error) {
if err := k.UpdateContractAdmin(ctx, msg.Contract, msg.Sender, msg.NewAdmin); err != nil {
return nil, err
}
Expand All @@ -149,7 +149,7 @@ func handleUpdateContractAdmin(ctx sdk.Context, k Keeper, msg *MsgUpdateAdmin) (
}, nil
}

func handleClearContractAdmin(ctx sdk.Context, k Keeper, msg *MsgClearAdmin) (*sdk.Result, error) {
func handleClearContractAdmin(ctx sdk.Context, k *Keeper, msg *MsgClearAdmin) (*sdk.Result, error) {
if err := k.ClearContractAdmin(ctx, msg.Contract, msg.Sender); err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// InitGenesis sets supply information for genesis.
//
// CONTRACT: all types of accounts must have been already initialized/created
func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error {
func InitGenesis(ctx sdk.Context, keeper *Keeper, data types.GenesisState) error {
var maxCodeID uint64
for i, code := range data.Codes {
err := keeper.importCode(ctx, code.CodeID, code.CodeInfo, code.CodeBytes)
Expand Down Expand Up @@ -52,7 +52,7 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error
}

// ExportGenesis returns a GenesisState for a given context and keeper.
func ExportGenesis(ctx sdk.Context, keeper Keeper) *types.GenesisState {
func ExportGenesis(ctx sdk.Context, keeper *Keeper) *types.GenesisState {
var genState types.GenesisState

genState.Params = keeper.GetParams(ctx)
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func TestImportContractWithCodeHistoryReset(t *testing.T) {
assert.Equal(t, expHistory, keeper.GetContractHistory(ctx, contractAddr).CodeHistoryEntries)
}

func setupKeeper(t *testing.T) (Keeper, sdk.Context, []sdk.StoreKey, func()) {
func setupKeeper(t *testing.T) (*Keeper, sdk.Context, []sdk.StoreKey, func()) {
t.Helper()
tempDir, err := ioutil.TempDir("", "wasm")
require.NoError(t, err)
Expand Down Expand Up @@ -503,5 +503,5 @@ func setupKeeper(t *testing.T) (Keeper, sdk.Context, []sdk.StoreKey, func()) {
srcKeeper := NewKeeper(encodingConfig.Marshaler, keyWasm, pk.Subspace(wasmTypes.DefaultParamspace), authkeeper.AccountKeeper{}, nil, stakingkeeper.Keeper{}, distributionkeeper.Keeper{}, nil, tempDir, wasmConfig, "", nil, nil)
srcKeeper.setParams(ctx, wasmTypes.DefaultParams())

return srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams}, cleanup
return &srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams}, cleanup
}
49 changes: 22 additions & 27 deletions x/wasm/internal/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,38 +436,33 @@ func TestInstantiateWithNonExistingCodeID(t *testing.T) {
func TestInstantiateWithCallbackToContract(t *testing.T) {
ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil)
var (
wasmerMock = &selfCallingInstMockWasmer{}
err error
excuteCalled bool
err error
)
keepers.WasmKeeperRef.wasmer = wasmerMock
wasmerMock := &MockWasmer{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this constructor over extending AlwaysFailMockWasmer. But it does seem a bit odd to inline this.. Fine as now, as it is only used once, but maybe we can make it reusable for other tests. My rough thought:

func selfCallingInstMockWasmer(executeCalled *bool) *MockWasmer {
    return &MockWasmer{

}	CreateFn: func(code cosmwasm.WasmCode) (cosmwasm.CodeID, error) {
			anyCodeID := bytes.Repeat([]byte{0x1}, 32)
			return anyCodeID, nil
		},
		InstantiateFn: func(code cosmwasm.CodeID, env wasmTypes.Env, info wasmTypes.MessageInfo, initMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) (*wasmTypes.InitResponse, uint64, error) {
			return &wasmTypes.InitResponse{
				Messages: []wasmTypes.CosmosMsg{
					{Wasm: &wasmTypes.WasmMsg{Execute: &wasmTypes.ExecuteMsg{ContractAddr: env.Contract.Address, Msg: []byte(`{}`)}}},
				},
			}, 1, nil
		},
		ExecuteFn: func(code cosmwasm.CodeID, env wasmTypes.Env, info wasmTypes.MessageInfo, executeMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) (*wasmTypes.HandleResponse, uint64, error) {
			excuteCalled = true
			return &wasmTypes.HandleResponse{}, 1, nil
		},
	}
}

And in this test we could do:

executedCalled := false
wasmerMock = selfCallingInstMockWasmer(&executeCalled)
// ...
_, err := keepers.WasmKeeper.Instantiate(ctx, example.CodeID, example.CreatorAddr, nil, nil, "test", nil)
assert.True(t, excuteCalled)

Not a blocker and this can always be refactored later. I like the use of MockWasmer.

CreateFn: func(code cosmwasm.WasmCode) (cosmwasm.CodeID, error) {
anyCodeID := bytes.Repeat([]byte{0x1}, 32)
return anyCodeID, nil
},
InstantiateFn: func(code cosmwasm.CodeID, env wasmTypes.Env, info wasmTypes.MessageInfo, initMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) (*wasmTypes.InitResponse, uint64, error) {
return &wasmTypes.InitResponse{
Messages: []wasmTypes.CosmosMsg{
{Wasm: &wasmTypes.WasmMsg{Execute: &wasmTypes.ExecuteMsg{ContractAddr: env.Contract.Address, Msg: []byte(`{}`)}}},
},
}, 1, nil
},
ExecuteFn: func(code cosmwasm.CodeID, env wasmTypes.Env, info wasmTypes.MessageInfo, executeMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) (*wasmTypes.HandleResponse, uint64, error) {
excuteCalled = true
return &wasmTypes.HandleResponse{}, 1, nil
},
}

keepers.WasmKeeper.wasmer = wasmerMock
keepers.WasmKeeper.wasmer = wasmerMock
example := StoreHackatomExampleContract(t, ctx, keepers)
_, err = keepers.WasmKeeper.Instantiate(ctx, example.CodeID, example.CreatorAddr, nil, nil, "test", nil)
require.NoError(t, err)
assert.True(t, wasmerMock.excuteCalled)
}

// mock to call itself on instantiation
type selfCallingInstMockWasmer struct {
AlwaysFailMockWasmer
excuteCalled bool
}

func (a *selfCallingInstMockWasmer) Create(code cosmwasm.WasmCode) (cosmwasm.CodeID, error) {
return bytes.Repeat([]byte{0x1}, 32), nil
}

func (a *selfCallingInstMockWasmer) Instantiate(code cosmwasm.CodeID, env wasmTypes.Env, info wasmTypes.MessageInfo, initMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) (*wasmTypes.InitResponse, uint64, error) {
return &wasmTypes.InitResponse{
Messages: []wasmTypes.CosmosMsg{
{Wasm: &wasmTypes.WasmMsg{Execute: &wasmTypes.ExecuteMsg{ContractAddr: env.Contract.Address, Msg: []byte(`{}`)}}},
},
}, 1, nil
}

func (a *selfCallingInstMockWasmer) Execute(code cosmwasm.CodeID, env wasmTypes.Env, info wasmTypes.MessageInfo, executeMsg []byte, store cosmwasm.KVStore, goapi cosmwasm.GoAPI, querier cosmwasm.Querier, gasMeter cosmwasm.GasMeter, gasLimit uint64) (*wasmTypes.HandleResponse, uint64, error) {
a.excuteCalled = true
return &wasmTypes.HandleResponse{}, 1, nil
assert.True(t, excuteCalled)
}

func TestExecute(t *testing.T) {
Expand Down
12 changes: 6 additions & 6 deletions x/wasm/internal/keeper/legacy_querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (
)

// NewLegacyQuerier creates a new querier
func NewLegacyQuerier(keeper Keeper) sdk.Querier {
func NewLegacyQuerier(keeper *Keeper) sdk.Querier {
return func(ctx sdk.Context, path []string, req abci.RequestQuery) ([]byte, error) {
var (
rsp interface{}
Expand All @@ -39,13 +39,13 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, err.Error())
}
rsp, err = queryContractInfo(ctx, addr, keeper)
rsp, err = queryContractInfo(ctx, addr, *keeper)
case QueryListContractByCode:
codeID, err := strconv.ParseUint(path[1], 10, 64)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalid, "code id: %s", err.Error())
}
rsp, err = queryContractListByCode(ctx, codeID, keeper)
rsp, err = queryContractListByCode(ctx, codeID, *keeper)
case QueryGetContractState:
if len(path) < 3 {
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown data query endpoint")
Expand All @@ -58,13 +58,13 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
}
rsp, err = queryCode(ctx, codeID, keeper)
case QueryListCode:
rsp, err = queryCodeList(ctx, keeper)
rsp, err = queryCodeList(ctx, *keeper)
case QueryContractHistory:
contractAddr, err := sdk.AccAddressFromBech32(path[1])
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, err.Error())
}
rsp, err = queryContractHistory(ctx, contractAddr, keeper)
rsp, err = queryContractHistory(ctx, contractAddr, *keeper)
default:
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown data query endpoint")
}
Expand All @@ -82,7 +82,7 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
}
}

func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, keeper Keeper) (json.RawMessage, error) {
func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, keeper *Keeper) (json.RawMessage, error) {
contractAddr, err := sdk.AccAddressFromBech32(bech)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, bech)
Expand Down
14 changes: 7 additions & 7 deletions x/wasm/internal/keeper/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ import (
)

type grpcQuerier struct {
keeper Keeper
keeper *Keeper
}

// todo: this needs proper tests and doc
func NewQuerier(keeper Keeper) grpcQuerier {
func NewQuerier(keeper *Keeper) grpcQuerier {
return grpcQuerier{keeper: keeper}
}

func (q grpcQuerier) ContractInfo(c context.Context, req *types.QueryContractInfoRequest) (*types.QueryContractInfoResponse, error) {
if err := sdk.VerifyAddressFormat(req.Address); err != nil {
return nil, err
}
rsp, err := queryContractInfo(sdk.UnwrapSDKContext(c), req.Address, q.keeper)
rsp, err := queryContractInfo(sdk.UnwrapSDKContext(c), req.Address, *q.keeper)
switch {
case err != nil:
return nil, err
Expand All @@ -40,7 +40,7 @@ func (q grpcQuerier) ContractHistory(c context.Context, req *types.QueryContract
if err := sdk.VerifyAddressFormat(req.Address); err != nil {
return nil, err
}
rsp, err := queryContractHistory(sdk.UnwrapSDKContext(c), req.Address, q.keeper)
rsp, err := queryContractHistory(sdk.UnwrapSDKContext(c), req.Address, *q.keeper)
switch {
case err != nil:
return nil, err
Expand All @@ -56,7 +56,7 @@ func (q grpcQuerier) ContractsByCode(c context.Context, req *types.QueryContract
if req.CodeId == 0 {
return nil, sdkerrors.Wrap(types.ErrInvalid, "code id")
}
rsp, err := queryContractListByCode(sdk.UnwrapSDKContext(c), req.CodeId, q.keeper)
rsp, err := queryContractListByCode(sdk.UnwrapSDKContext(c), req.CodeId, *q.keeper)
switch {
case err != nil:
return nil, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func (q grpcQuerier) Code(c context.Context, req *types.QueryCodeRequest) (*type
}

func (q grpcQuerier) Codes(c context.Context, _ *empty.Empty) (*types.QueryCodesResponse, error) {
rsp, err := queryCodeList(sdk.UnwrapSDKContext(c), q.keeper)
rsp, err := queryCodeList(sdk.UnwrapSDKContext(c), *q.keeper)
switch {
case err != nil:
return nil, err
Expand Down Expand Up @@ -182,7 +182,7 @@ func queryContractListByCode(ctx sdk.Context, codeID uint64, keeper Keeper) ([]t
return contracts, nil
}

func queryCode(ctx sdk.Context, codeID uint64, keeper Keeper) (*types.QueryCodeResponse, error) {
func queryCode(ctx sdk.Context, codeID uint64, keeper *Keeper) (*types.QueryCodeResponse, error) {
if codeID == 0 {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/recurse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type recurseResponse struct {
// number os wasm queries called from a contract
var totalWasmQueryCounter int

func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper Keeper) {
func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper *Keeper) {
// we do one basic setup before all test cases (which are read-only and don't change state)
var realWasmQuerier func(ctx sdk.Context, request *wasmTypes.WasmQuery) ([]byte, error)
countingQuerier := &QueryPlugins{
Expand All @@ -48,7 +48,7 @@ func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.Acc

ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, countingQuerier)
keeper = keepers.WasmKeeper
realWasmQuerier = WasmQuerier(&keeper)
realWasmQuerier = WasmQuerier(keeper)

exampleContract := InstantiateHackatomExampleContract(t, ctx, keepers)
return exampleContract.Contract, exampleContract.CreatorAddr, ctx, keeper
Expand Down
2 changes: 1 addition & 1 deletion x/wasm/internal/keeper/staking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func initializeStaking(t *testing.T) initInfo {
ctx: ctx,
accKeeper: accKeeper,
stakingKeeper: stakingKeeper,
wasmKeeper: keeper,
wasmKeeper: *keeper,
distKeeper: k.DistKeeper,
bankKeeper: bankKeeper,
}
Expand Down
Loading