Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support self calling contract on instantiation #300

Merged
merged 3 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
distr.NewAppModule(appCodec, app.distrKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
staking.NewAppModule(appCodec, app.stakingKeeper, app.accountKeeper, app.bankKeeper),
upgrade.NewAppModule(app.upgradeKeeper),
wasm.NewAppModule(app.wasmKeeper),
wasm.NewAppModule(&app.wasmKeeper),
evidence.NewAppModule(app.evidenceKeeper),
ibc.NewAppModule(app.ibcKeeper),
params.NewAppModule(app.paramsKeeper),
Expand Down Expand Up @@ -472,7 +472,7 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
distr.NewAppModule(appCodec, app.distrKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
slashing.NewAppModule(appCodec, app.slashingKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
params.NewAppModule(app.paramsKeeper),
wasm.NewAppModule(app.wasmKeeper),
wasm.NewAppModule(&app.wasmKeeper),
evidence.NewAppModule(app.evidenceKeeper),
ibc.NewAppModule(app.ibcKeeper),
transferModule,
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ func TestInitGenesis(t *testing.T) {
})

// export into genstate
genState := ExportGenesis(data.ctx, data.keeper)
genState := ExportGenesis(data.ctx, &data.keeper)

// create new app to import genstate into
newData := setupTest(t)
q2 := newData.module.LegacyQuerierHandler(nil)

// initialize new app with genstate
InitGenesis(newData.ctx, newData.keeper, *genState)
InitGenesis(newData.ctx, &newData.keeper, *genState)

// run same checks again on newdata, to make sure it was reinitialized correctly
assertCodeList(t, q2, newData.ctx, 1)
Expand Down
14 changes: 7 additions & 7 deletions x/wasm/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

// NewHandler returns a handler for "bank" type messages.
func NewHandler(k Keeper) sdk.Handler {
func NewHandler(k *Keeper) sdk.Handler {
return func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())

Expand Down Expand Up @@ -47,7 +47,7 @@ func filteredMessageEvents(manager *sdk.EventManager) []abci.Event {
return res
}

func handleStoreCode(ctx sdk.Context, k Keeper, msg *MsgStoreCode) (*sdk.Result, error) {
func handleStoreCode(ctx sdk.Context, k *Keeper, msg *MsgStoreCode) (*sdk.Result, error) {
err := msg.ValidateBasic()
if err != nil {
return nil, err
Expand All @@ -73,7 +73,7 @@ func handleStoreCode(ctx sdk.Context, k Keeper, msg *MsgStoreCode) (*sdk.Result,
}, nil
}

func handleInstantiate(ctx sdk.Context, k Keeper, msg *MsgInstantiateContract) (*sdk.Result, error) {
func handleInstantiate(ctx sdk.Context, k *Keeper, msg *MsgInstantiateContract) (*sdk.Result, error) {
contractAddr, err := k.Instantiate(ctx, msg.CodeID, msg.Sender, msg.Admin, msg.InitMsg, msg.Label, msg.InitFunds)
if err != nil {
return nil, err
Expand All @@ -95,7 +95,7 @@ func handleInstantiate(ctx sdk.Context, k Keeper, msg *MsgInstantiateContract) (
}, nil
}

func handleExecute(ctx sdk.Context, k Keeper, msg *MsgExecuteContract) (*sdk.Result, error) {
func handleExecute(ctx sdk.Context, k *Keeper, msg *MsgExecuteContract) (*sdk.Result, error) {
res, err := k.Execute(ctx, msg.Contract, msg.Sender, msg.Msg, msg.SentFunds)
if err != nil {
return nil, err
Expand All @@ -115,7 +115,7 @@ func handleExecute(ctx sdk.Context, k Keeper, msg *MsgExecuteContract) (*sdk.Res
return res, nil
}

func handleMigration(ctx sdk.Context, k Keeper, msg *MsgMigrateContract) (*sdk.Result, error) {
func handleMigration(ctx sdk.Context, k *Keeper, msg *MsgMigrateContract) (*sdk.Result, error) {
res, err := k.Migrate(ctx, msg.Contract, msg.Sender, msg.CodeID, msg.MigrateMsg)
if err != nil {
return nil, err
Expand All @@ -133,7 +133,7 @@ func handleMigration(ctx sdk.Context, k Keeper, msg *MsgMigrateContract) (*sdk.R
return res, nil
}

func handleUpdateContractAdmin(ctx sdk.Context, k Keeper, msg *MsgUpdateAdmin) (*sdk.Result, error) {
func handleUpdateContractAdmin(ctx sdk.Context, k *Keeper, msg *MsgUpdateAdmin) (*sdk.Result, error) {
if err := k.UpdateContractAdmin(ctx, msg.Contract, msg.Sender, msg.NewAdmin); err != nil {
return nil, err
}
Expand All @@ -149,7 +149,7 @@ func handleUpdateContractAdmin(ctx sdk.Context, k Keeper, msg *MsgUpdateAdmin) (
}, nil
}

func handleClearContractAdmin(ctx sdk.Context, k Keeper, msg *MsgClearAdmin) (*sdk.Result, error) {
func handleClearContractAdmin(ctx sdk.Context, k *Keeper, msg *MsgClearAdmin) (*sdk.Result, error) {
if err := k.ClearContractAdmin(ctx, msg.Contract, msg.Sender); err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// InitGenesis sets supply information for genesis.
//
// CONTRACT: all types of accounts must have been already initialized/created
func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error {
func InitGenesis(ctx sdk.Context, keeper *Keeper, data types.GenesisState) error {
var maxCodeID uint64
for i, code := range data.Codes {
err := keeper.importCode(ctx, code.CodeID, code.CodeInfo, code.CodeBytes)
Expand Down Expand Up @@ -52,7 +52,7 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error
}

// ExportGenesis returns a GenesisState for a given context and keeper.
func ExportGenesis(ctx sdk.Context, keeper Keeper) *types.GenesisState {
func ExportGenesis(ctx sdk.Context, keeper *Keeper) *types.GenesisState {
var genState types.GenesisState

genState.Params = keeper.GetParams(ctx)
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func TestImportContractWithCodeHistoryReset(t *testing.T) {
assert.Equal(t, expHistory, keeper.GetContractHistory(ctx, contractAddr).CodeHistoryEntries)
}

func setupKeeper(t *testing.T) (Keeper, sdk.Context, []sdk.StoreKey, func()) {
func setupKeeper(t *testing.T) (*Keeper, sdk.Context, []sdk.StoreKey, func()) {
t.Helper()
tempDir, err := ioutil.TempDir("", "wasm")
require.NoError(t, err)
Expand Down Expand Up @@ -503,5 +503,5 @@ func setupKeeper(t *testing.T) (Keeper, sdk.Context, []sdk.StoreKey, func()) {
srcKeeper := NewKeeper(encodingConfig.Marshaler, keyWasm, pk.Subspace(wasmTypes.DefaultParamspace), authkeeper.AccountKeeper{}, nil, stakingkeeper.Keeper{}, distributionkeeper.Keeper{}, nil, tempDir, wasmConfig, "", nil, nil)
srcKeeper.setParams(ctx, wasmTypes.DefaultParams())

return srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams}, cleanup
return &srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams}, cleanup
}
16 changes: 9 additions & 7 deletions x/wasm/internal/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Keeper struct {
accountKeeper authkeeper.AccountKeeper
bankKeeper bankkeeper.Keeper

wasmer wasm.Wasmer
wasmer types.WasmerEngine
queryPlugins QueryPlugins
messenger MessageHandler
// queryGasLimit is the max wasm gas that can be spent on executing a query with a contract
Expand Down Expand Up @@ -86,7 +86,7 @@ func NewKeeper(
keeper := Keeper{
storeKey: storeKey,
cdc: cdc,
wasmer: *wasmer,
wasmer: wasmer,
accountKeeper: accountKeeper,
bankKeeper: bankKeeper,
messenger: NewMessageHandler(router, customEncoders),
Expand Down Expand Up @@ -248,16 +248,18 @@ func (k Keeper) instantiate(ctx sdk.Context, codeID uint64, creator, admin sdk.A
events := types.ParseEvents(res.Attributes, contractAddress)
ctx.EventManager().EmitEvents(events)

// persist instance first
createdAt := types.NewAbsoluteTxPosition(ctx)
instance := types.NewContractInfo(codeID, creator, admin, label, createdAt)
store.Set(types.GetContractAddressKey(contractAddress), k.cdc.MustMarshalBinaryBare(&instance))
k.appendToContractHistory(ctx, contractAddress, instance.InitialHistory(initMsg))

// then dispatch so that contract could be called back
err = k.dispatchMessages(ctx, contractAddress, res.Messages)
if err != nil {
return nil, err
}

// persist instance
createdAt := types.NewAbsoluteTxPosition(ctx)
instance := types.NewContractInfo(codeID, creator, admin, label, createdAt)
store.Set(types.GetContractAddressKey(contractAddress), k.cdc.MustMarshalBinaryBare(&instance))
k.appendToContractHistory(ctx, contractAddress, instance.InitialHistory(initMsg))
return contractAddress, nil
}

Expand Down
15 changes: 15 additions & 0 deletions x/wasm/internal/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,21 @@ func TestInstantiateWithNonExistingCodeID(t *testing.T) {
require.Nil(t, addr)
}

func TestInstantiateWithCallbackToContract(t *testing.T) {
ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil)
var (
executeCalled bool
err error
)
wasmerMock := selfCallingInstMockWasmer(&executeCalled)

keepers.WasmKeeper.wasmer = wasmerMock
example := StoreHackatomExampleContract(t, ctx, keepers)
_, err = keepers.WasmKeeper.Instantiate(ctx, example.CodeID, example.CreatorAddr, nil, nil, "test", nil)
require.NoError(t, err)
assert.True(t, executeCalled)
}

func TestExecute(t *testing.T) {
ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil)
accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper
Expand Down
12 changes: 6 additions & 6 deletions x/wasm/internal/keeper/legacy_querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (
)

// NewLegacyQuerier creates a new querier
func NewLegacyQuerier(keeper Keeper) sdk.Querier {
func NewLegacyQuerier(keeper *Keeper) sdk.Querier {
return func(ctx sdk.Context, path []string, req abci.RequestQuery) ([]byte, error) {
var (
rsp interface{}
Expand All @@ -39,13 +39,13 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, err.Error())
}
rsp, err = queryContractInfo(ctx, addr, keeper)
rsp, err = queryContractInfo(ctx, addr, *keeper)
case QueryListContractByCode:
codeID, err := strconv.ParseUint(path[1], 10, 64)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalid, "code id: %s", err.Error())
}
rsp, err = queryContractListByCode(ctx, codeID, keeper)
rsp, err = queryContractListByCode(ctx, codeID, *keeper)
case QueryGetContractState:
if len(path) < 3 {
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown data query endpoint")
Expand All @@ -58,13 +58,13 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
}
rsp, err = queryCode(ctx, codeID, keeper)
case QueryListCode:
rsp, err = queryCodeList(ctx, keeper)
rsp, err = queryCodeList(ctx, *keeper)
case QueryContractHistory:
contractAddr, err := sdk.AccAddressFromBech32(path[1])
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, err.Error())
}
rsp, err = queryContractHistory(ctx, contractAddr, keeper)
rsp, err = queryContractHistory(ctx, contractAddr, *keeper)
default:
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown data query endpoint")
}
Expand All @@ -82,7 +82,7 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
}
}

func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, keeper Keeper) (json.RawMessage, error) {
func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, keeper *Keeper) (json.RawMessage, error) {
contractAddr, err := sdk.AccAddressFromBech32(bech)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, bech)
Expand Down
14 changes: 7 additions & 7 deletions x/wasm/internal/keeper/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ import (
)

type grpcQuerier struct {
keeper Keeper
keeper *Keeper
}

// todo: this needs proper tests and doc
func NewQuerier(keeper Keeper) grpcQuerier {
func NewQuerier(keeper *Keeper) grpcQuerier {
return grpcQuerier{keeper: keeper}
}

func (q grpcQuerier) ContractInfo(c context.Context, req *types.QueryContractInfoRequest) (*types.QueryContractInfoResponse, error) {
if err := sdk.VerifyAddressFormat(req.Address); err != nil {
return nil, err
}
rsp, err := queryContractInfo(sdk.UnwrapSDKContext(c), req.Address, q.keeper)
rsp, err := queryContractInfo(sdk.UnwrapSDKContext(c), req.Address, *q.keeper)
switch {
case err != nil:
return nil, err
Expand All @@ -40,7 +40,7 @@ func (q grpcQuerier) ContractHistory(c context.Context, req *types.QueryContract
if err := sdk.VerifyAddressFormat(req.Address); err != nil {
return nil, err
}
rsp, err := queryContractHistory(sdk.UnwrapSDKContext(c), req.Address, q.keeper)
rsp, err := queryContractHistory(sdk.UnwrapSDKContext(c), req.Address, *q.keeper)
switch {
case err != nil:
return nil, err
Expand All @@ -56,7 +56,7 @@ func (q grpcQuerier) ContractsByCode(c context.Context, req *types.QueryContract
if req.CodeId == 0 {
return nil, sdkerrors.Wrap(types.ErrInvalid, "code id")
}
rsp, err := queryContractListByCode(sdk.UnwrapSDKContext(c), req.CodeId, q.keeper)
rsp, err := queryContractListByCode(sdk.UnwrapSDKContext(c), req.CodeId, *q.keeper)
switch {
case err != nil:
return nil, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func (q grpcQuerier) Code(c context.Context, req *types.QueryCodeRequest) (*type
}

func (q grpcQuerier) Codes(c context.Context, _ *empty.Empty) (*types.QueryCodesResponse, error) {
rsp, err := queryCodeList(sdk.UnwrapSDKContext(c), q.keeper)
rsp, err := queryCodeList(sdk.UnwrapSDKContext(c), *q.keeper)
switch {
case err != nil:
return nil, err
Expand Down Expand Up @@ -182,7 +182,7 @@ func queryContractListByCode(ctx sdk.Context, codeID uint64, keeper Keeper) ([]t
return contracts, nil
}

func queryCode(ctx sdk.Context, codeID uint64, keeper Keeper) (*types.QueryCodeResponse, error) {
func queryCode(ctx sdk.Context, codeID uint64, keeper *Keeper) (*types.QueryCodeResponse, error) {
if codeID == 0 {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/recurse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type recurseResponse struct {
// number os wasm queries called from a contract
var totalWasmQueryCounter int

func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper Keeper) {
func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper *Keeper) {
// we do one basic setup before all test cases (which are read-only and don't change state)
var realWasmQuerier func(ctx sdk.Context, request *wasmTypes.WasmQuery) ([]byte, error)
countingQuerier := &QueryPlugins{
Expand All @@ -48,7 +48,7 @@ func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.Acc

ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, countingQuerier)
keeper = keepers.WasmKeeper
realWasmQuerier = WasmQuerier(&keeper)
realWasmQuerier = WasmQuerier(keeper)

exampleContract := InstantiateHackatomExampleContract(t, ctx, keepers)
return exampleContract.Contract, exampleContract.CreatorAddr, ctx, keeper
Expand Down
2 changes: 1 addition & 1 deletion x/wasm/internal/keeper/staking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func initializeStaking(t *testing.T) initInfo {
ctx: ctx,
accKeeper: accKeeper,
stakingKeeper: stakingKeeper,
wasmKeeper: keeper,
wasmKeeper: *keeper,
distKeeper: k.DistKeeper,
bankKeeper: bankKeeper,
}
Expand Down
Loading