diff --git a/x/wasm/keeper/handler_plugin.go b/x/wasm/keeper/handler_plugin.go index 50bd1a1b67..796fa54d97 100644 --- a/x/wasm/keeper/handler_plugin.go +++ b/x/wasm/keeper/handler_plugin.go @@ -124,6 +124,20 @@ func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Ad return nil, errorsmod.Wrapf(sdkerrors.ErrUnknownRequest, "can't route message %+v", msg) } +type callDepthMessageHandler struct { + Messenger + MaxCallDepth uint32 +} + +func (h callDepthMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, msgResponses [][]*codectypes.Any, err error) { + ctx, err = checkAndIncreaseCallDepth(ctx, h.MaxCallDepth) + if err != nil { + return nil, nil, nil, err + } + + return h.Messenger.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg) +} + // MessageHandlerChain defines a chain of handlers that are called one by one until it can be handled. type MessageHandlerChain struct { handlers []Messenger diff --git a/x/wasm/keeper/keeper.go b/x/wasm/keeper/keeper.go index 443de1b601..2e6cdac178 100644 --- a/x/wasm/keeper/keeper.go +++ b/x/wasm/keeper/keeper.go @@ -98,6 +98,7 @@ type Keeper struct { queryGasLimit uint64 gasRegister types.GasRegister maxQueryStackSize uint32 + maxCallDepth uint32 acceptedAccountTypes map[reflect.Type]struct{} accountPruner AccountPruner params collections.Item[types.Params] @@ -785,6 +786,7 @@ func (k Keeper) mustGetLastContractHistoryEntry(ctx context.Context, contractAdd // QuerySmart queries the smart contract itself. func (k Keeper) QuerySmart(ctx context.Context, contractAddr sdk.AccAddress, req []byte) ([]byte, error) { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "query-smart") + // checks and increase query stack size sdkCtx, err := checkAndIncreaseQueryStackSize(sdk.UnwrapSDKContext(ctx), k.maxQueryStackSize) if err != nil { @@ -832,6 +834,24 @@ func checkAndIncreaseQueryStackSize(ctx context.Context, maxQueryStackSize uint3 return types.WithQueryStackSize(sdk.UnwrapSDKContext(ctx), queryStackSize), nil } +func checkAndIncreaseCallDepth(ctx context.Context, maxCallDepth uint32) (sdk.Context, error) { + var callDepth uint32 = 0 + if size, ok := types.CallDepth(ctx); ok { + callDepth = size + } + + // increase + callDepth++ + + // did we go too far? + if callDepth > maxCallDepth { + return sdk.Context{}, types.ErrExceedMaxCallDepth + } + + // set updated stack size + return types.WithCallDepth(sdk.UnwrapSDKContext(ctx), callDepth), nil +} + // QueryRaw returns the contract's state for give key. Returns `nil` when key is `nil`. func (k Keeper) QueryRaw(ctx context.Context, contractAddress sdk.AccAddress, key []byte) []byte { defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "query-raw") diff --git a/x/wasm/keeper/keeper_cgo.go b/x/wasm/keeper/keeper_cgo.go index 020fef834d..2e0b234e44 100644 --- a/x/wasm/keeper/keeper_cgo.go +++ b/x/wasm/keeper/keeper_cgo.go @@ -51,6 +51,7 @@ func NewKeeper( queryGasLimit: wasmConfig.SmartQueryGasLimit, gasRegister: types.NewDefaultWasmGasRegister(), maxQueryStackSize: types.DefaultMaxQueryStackSize, + maxCallDepth: types.DefaultMaxCallDepth, acceptedAccountTypes: defaultAcceptedAccountTypes, params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)), propagateGovAuthorization: map[types.AuthorizationPolicyAction]struct{}{ @@ -63,6 +64,8 @@ func NewKeeper( for _, o := range preOpts { o.apply(keeper) } + // always wrap the messenger, even if it was replaced by an option + keeper.messenger = callDepthMessageHandler{keeper.messenger, keeper.maxCallDepth} // only set the wasmvm if no one set this in the options // NewVM does a lot, so better not to create it and silently drop it. if keeper.wasmVM == nil { diff --git a/x/wasm/keeper/options.go b/x/wasm/keeper/options.go index 234197be8d..8b001930ef 100644 --- a/x/wasm/keeper/options.go +++ b/x/wasm/keeper/options.go @@ -158,6 +158,12 @@ func WithMaxQueryStackSize(m uint32) Option { }) } +func WithMaxCallDepth(m uint32) Option { + return optsFn(func(k *Keeper) { + k.maxCallDepth = m + }) +} + // WithAcceptedAccountTypesOnContractInstantiation sets the accepted account types. Account types of this list won't be overwritten or cause a failure // when they exist for an address on contract instantiation. // diff --git a/x/wasm/keeper/options_test.go b/x/wasm/keeper/options_test.go index 1ac311c14c..4a780245f0 100644 --- a/x/wasm/keeper/options_test.go +++ b/x/wasm/keeper/options_test.go @@ -59,7 +59,9 @@ func TestConstructorOptions(t *testing.T) { "message handler": { srcOpt: WithMessageHandler(&wasmtesting.MockMessageHandler{}), verify: func(t *testing.T, k Keeper) { - assert.IsType(t, &wasmtesting.MockMessageHandler{}, k.messenger) + require.IsType(t, callDepthMessageHandler{}, k.messenger) + messenger, _ := k.messenger.(callDepthMessageHandler) + assert.IsType(t, &wasmtesting.MockMessageHandler{}, messenger.Messenger) }, }, "query plugins": { @@ -70,7 +72,7 @@ func TestConstructorOptions(t *testing.T) { }, "message handler decorator": { srcOpt: WithMessageHandlerDecorator(func(old Messenger) Messenger { - require.IsType(t, &MessageHandlerChain{}, old) + require.IsType(t, callDepthMessageHandler{}, old) return &wasmtesting.MockMessageHandler{} }), verify: func(t *testing.T, k Keeper) { @@ -108,12 +110,18 @@ func TestConstructorOptions(t *testing.T) { assert.Equal(t, uint64(2), costCanonical) }, }, - "max recursion query limit": { + "max query recursion limit": { srcOpt: WithMaxQueryStackSize(1), verify: func(t *testing.T, k Keeper) { assert.IsType(t, uint32(1), k.maxQueryStackSize) }, }, + "max message recursion limit": { + srcOpt: WithMaxCallDepth(1), + verify: func(t *testing.T, k Keeper) { + assert.IsType(t, uint32(1), k.maxCallDepth) + }, + }, "accepted account types": { srcOpt: WithAcceptedAccountTypesOnContractInstantiation(&authtypes.BaseAccount{}, &vestingtypes.ContinuousVestingAccount{}), verify: func(t *testing.T, k Keeper) { diff --git a/x/wasm/keeper/query_plugins_test.go b/x/wasm/keeper/query_plugins_test.go index 1fea83a25f..a36a8139f0 100644 --- a/x/wasm/keeper/query_plugins_test.go +++ b/x/wasm/keeper/query_plugins_test.go @@ -561,7 +561,7 @@ func TestQueryErrors(t *testing.T) { return nil, spec.src }) ms := store.NewCommitMultiStore(dbm.NewMemDB(), log.NewTestLogger(t), storemetrics.NewNoOpMetrics()) - ctx := sdk.Context{}.WithGasMeter(storetypes.NewInfiniteGasMeter()).WithMultiStore(ms).WithLogger(log.NewTestLogger(t)) + ctx := sdk.NewContext(ms, cmtproto.Header{}, false, log.NewTestLogger(t)).WithGasMeter(storetypes.NewInfiniteGasMeter()) q := keeper.NewQueryHandler(ctx, mock, sdk.AccAddress{}, types.NewDefaultWasmGasRegister()) _, gotErr := q.Query(wasmvmtypes.QueryRequest{}, 1) assert.Equal(t, spec.expErr, gotErr) diff --git a/x/wasm/types/context.go b/x/wasm/types/context.go index 60d5dedc04..002cf01aad 100644 --- a/x/wasm/types/context.go +++ b/x/wasm/types/context.go @@ -18,6 +18,8 @@ const ( contextKeySubMsgAuthzPolicy = iota // gas register contextKeyGasRegister = iota + + contextKeyCallDepth contextKey = iota ) // WithTXCounter stores a transaction counter value in the context @@ -43,6 +45,15 @@ func QueryStackSize(ctx context.Context) (uint32, bool) { return val, ok } +func WithCallDepth(ctx sdk.Context, counter uint32) sdk.Context { + return ctx.WithValue(contextKeyCallDepth, counter) +} + +func CallDepth(ctx context.Context) (uint32, bool) { + val, ok := ctx.Value(contextKeyCallDepth).(uint32) + return val, ok +} + // WithSubMsgAuthzPolicy stores the authorization policy for submessages into the context returned func WithSubMsgAuthzPolicy(ctx sdk.Context, policy AuthorizationPolicy) sdk.Context { if policy == nil { diff --git a/x/wasm/types/errors.go b/x/wasm/types/errors.go index 00ea45dc68..7ed79d5c4c 100644 --- a/x/wasm/types/errors.go +++ b/x/wasm/types/errors.go @@ -89,6 +89,9 @@ var ( // ErrVMError means an error occurred in wasmvm (not in the contract itself, but in the host environment) ErrVMError = errorsmod.Register(DefaultCodespace, 29, "wasmvm error") + + // ErrExceedMaxCallDepth error if max message stack size is exceeded + ErrExceedMaxCallDepth = errorsmod.Register(DefaultCodespace, 30, "max call depth exceeded") ) // WasmVMErrorable mapped error type in wasmvm and are not redacted diff --git a/x/wasm/types/wasmer_engine.go b/x/wasm/types/wasmer_engine.go index f3829e9146..9d58a77078 100644 --- a/x/wasm/types/wasmer_engine.go +++ b/x/wasm/types/wasmer_engine.go @@ -7,9 +7,11 @@ import ( storetypes "cosmossdk.io/store/types" ) -// DefaultMaxQueryStackSize maximum size of the stack of contract instances doing queries +// DefaultMaxQueryStackSize maximum size of the stack of recursive queries a contract can make const DefaultMaxQueryStackSize uint32 = 10 +const DefaultMaxCallDepth uint32 = 500 + // WasmEngine defines the WASM contract runtime engine. type WasmEngine interface { // StoreCode will compile the Wasm code, and store the resulting compiled module