diff --git a/CMakeLists.txt b/CMakeLists.txt index f0c69b5aa..87a29df19 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ if(CMAKE_BUILD_TYPE MATCHES "Debug") set(GFLAG -tags Debug) endif() -add_custom_target(build ALL DEPENDS aergocli aergosvr aergoluac brick) +add_custom_target(build ALL DEPENDS aergocli aergosvr aergovm aergoluac brick) add_custom_target(aergocli GO111MODULE=on GOBIN=${BIN_DIR} go install ${GCFLAGS} -ldflags \"-X github.com/aergoio/aergo/v2/cmd/aergocli/cmd.githash=`git describe --tags`\" ./cmd/aergocli/... WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} @@ -25,6 +25,12 @@ add_custom_target(aergosvr GO111MODULE=on GOBIN=${BIN_DIR} go install ${GCFLAGS} WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} DEPENDS libtool) +add_custom_target(aergovm + COMMAND GO111MODULE=on GOBIN=${BIN_DIR} go install ${GCFLAGS} -ldflags \"-X main.githash=`git describe --tags` -X main.gitRevision=`git rev-parse --short HEAD` -X main.gitBranch=`git rev-parse --symbolic-full-name --abbrev-ref HEAD`\" ./contract/vm/... + COMMAND ${CMAKE_COMMAND} -E rename ${BIN_DIR}/vm ${BIN_DIR}/aergovm + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + DEPENDS libtool) + add_custom_target(polaris GO111MODULE=on GOBIN=${BIN_DIR} go install ${GCFLAGS} -ldflags \"-X main.githash=`git describe --tags`\" ./cmd/polaris/... WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}) @@ -45,7 +51,8 @@ add_custom_target(mpdumpdiag GO111MODULE=on GOBIN=${BIN_DIR} go install ${GCFLAG add_custom_target(deps DEPENDS libtool) -add_custom_target(check GO111MODULE=on go test -timeout 600s ./... +add_custom_target(check + COMMAND ${CMAKE_COMMAND} -E env AERGOVM_PATH=${BIN_DIR}/aergovm GO111MODULE=on go test -timeout 600s ./... WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} DEPENDS build) add_custom_target(cover-check GO111MODULE=on go test -timeout 600s -coverprofile c.out ./... diff --git a/Makefile b/Makefile index b23750fe1..9f119fe2b 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ endif BUILD_RULES := \ deps \ - aergocli aergosvr aergoluac polaris colaris brick mpdumpdiag\ + aergocli aergosvr aergovm aergoluac polaris colaris brick mpdumpdiag \ libtool libtool-clean \ libluajit liblmdb libgmp \ libluajit-clean liblmdb-clean libgmp-clean \ diff --git a/chain/chainservice.go b/chain/chainservice.go index d69fd8af7..11feaa459 100644 --- a/chain/chainservice.go +++ b/chain/chainservice.go @@ -154,7 +154,7 @@ func (core *Core) GetGenesisInfo() *types.Genesis { return core.cdb.GetGenesisInfo() } -// Close closes chain & state DB. +// Close closes chain, state and contracts DBs and the VM pool func (core *Core) Close() { if core.sdb != nil { core.sdb.Close() @@ -163,6 +163,7 @@ func (core *Core) Close() { core.cdb.Close() } contract.CloseDatabase() + contract.StopVMPool() } // InitGenesisBlock initialize chain database and generate specified genesis block if necessary @@ -295,7 +296,7 @@ func NewChainService(cfg *cfg.Config) *ChainService { contract.PubNet = pubNet contract.TraceBlockNo = cfg.Blockchain.StateTrace contract.SetStateSQLMaxDBSize(cfg.SQL.MaxDbSize) - contract.StartLStateFactory((cfg.Blockchain.NumWorkers+2)*(int(contract.MaxCallDepth(cfg.Hardfork.Version(math.MaxUint64)))+2), cfg.Blockchain.NumLStateClosers, cfg.Blockchain.CloseLimit) + contract.StartVMPool((cfg.Blockchain.NumWorkers + 2) * int(contract.MaxCallDepth(cfg.Hardfork.Version(math.MaxUint64)))) contract.InitContext(cfg.Blockchain.NumWorkers + 2) // For a strict governance transaction validation. diff --git a/cmd/aergoluac/util/compile.c b/cmd/aergoluac/luac/compile.c similarity index 98% rename from cmd/aergoluac/util/compile.c rename to cmd/aergoluac/luac/compile.c index 0f645b0b0..778f9aff2 100644 --- a/cmd/aergoluac/util/compile.c +++ b/cmd/aergoluac/luac/compile.c @@ -8,7 +8,7 @@ #include "_cgo_export.h" lua_State *luac_vm_newstate() { - lua_State *L = luaL_newstate(3); + lua_State *L = luaL_newstate(5); if (L == NULL) { return NULL; } diff --git a/cmd/aergoluac/util/compile.h b/cmd/aergoluac/luac/compile.h similarity index 100% rename from cmd/aergoluac/util/compile.h rename to cmd/aergoluac/luac/compile.h diff --git a/cmd/aergoluac/luac/luac.go b/cmd/aergoluac/luac/luac.go new file mode 100644 index 000000000..1482992e1 --- /dev/null +++ b/cmd/aergoluac/luac/luac.go @@ -0,0 +1,131 @@ +package luac + +/* +#cgo CFLAGS: -I${SRCDIR}/../../../libtool/include/luajit-2.1 +#cgo LDFLAGS: ${SRCDIR}/../../../libtool/lib/libluajit-5.1.a -lm + +#include +#include +#include "compile.h" +*/ +import "C" +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io/ioutil" + "os" + "runtime" + "unsafe" + + "github.com/aergoio/aergo/v2/cmd/aergoluac/encoding" + "github.com/aergoio/aergo/v2/cmd/aergoluac/util" +) + +func NewLState() *C.lua_State { + L := C.luac_vm_newstate() + if L == nil { + runtime.GC() + L = C.luac_vm_newstate() + } + return L +} + +func CloseLState(L *C.lua_State) { + if L != nil { + C.luac_vm_close(L) + } +} + +func Compile(L *C.lua_State, code string) (util.LuaCode, error) { + cStr := C.CString(code) + defer C.free(unsafe.Pointer(cStr)) + if errMsg := C.vm_loadstring(L, cStr); errMsg != nil { + return nil, errors.New(C.GoString(errMsg)) + } + if errMsg := C.vm_stringdump(L); errMsg != nil { + return nil, errors.New(C.GoString(errMsg)) + } + return dumpToBytes(L), nil +} + +func CompileFromFile(srcFileName, outFileName, abiFileName string) error { + cSrcFileName := C.CString(srcFileName) + cOutFileName := C.CString(outFileName) + cAbiFileName := C.CString(abiFileName) + L := C.luac_vm_newstate() + defer C.free(unsafe.Pointer(cSrcFileName)) + defer C.free(unsafe.Pointer(cOutFileName)) + defer C.free(unsafe.Pointer(cAbiFileName)) + defer C.luac_vm_close(L) + + if errMsg := C.vm_compile(L, cSrcFileName, cOutFileName, cAbiFileName); errMsg != nil { + return errors.New(C.GoString(errMsg)) + } + return nil +} + +func DumpFromFile(srcFileName string) error { + cSrcFileName := C.CString(srcFileName) + L := C.luac_vm_newstate() + defer C.free(unsafe.Pointer(cSrcFileName)) + defer C.luac_vm_close(L) + + if errMsg := C.vm_loadfile(L, cSrcFileName); errMsg != nil { + return errors.New(C.GoString(errMsg)) + } + if errMsg := C.vm_stringdump(L); errMsg != nil { + return errors.New(C.GoString(errMsg)) + } + + fmt.Println(encoding.EncodeCode(dumpToBytes(L))) + return nil +} + +func DumpFromStdin() error { + fi, err := os.Stdin.Stat() + if err != nil { + return err + } + var buf []byte + if (fi.Mode() & os.ModeCharDevice) == 0 { + buf, err = ioutil.ReadAll(os.Stdin) + if err != nil { + return err + } + } else { + var bBuf bytes.Buffer + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + bBuf.WriteString(scanner.Text() + "\n") + } + if err = scanner.Err(); err != nil { + return err + } + buf = bBuf.Bytes() + } + srcCode := C.CString(string(buf)) + L := C.luac_vm_newstate() + defer C.free(unsafe.Pointer(srcCode)) + defer C.luac_vm_close(L) + + if errMsg := C.vm_loadstring(L, srcCode); errMsg != nil { + return errors.New(C.GoString(errMsg)) + } + if errMsg := C.vm_stringdump(L); errMsg != nil { + return errors.New(C.GoString(errMsg)) + } + fmt.Println(encoding.EncodeCode(dumpToBytes(L))) + return nil +} + +func dumpToBytes(L *C.lua_State) util.LuaCode { + var ( + c, a *C.char + lc, la C.size_t + ) + c = C.lua_tolstring(L, -2, &lc) + a = C.lua_tolstring(L, -1, &la) + return util.NewLuaCode(C.GoBytes(unsafe.Pointer(c), C.int(lc)), C.GoBytes(unsafe.Pointer(a), C.int(la))) +} diff --git a/cmd/aergoluac/util/state_module.c b/cmd/aergoluac/luac/state_module.c similarity index 100% rename from cmd/aergoluac/util/state_module.c rename to cmd/aergoluac/luac/state_module.c diff --git a/cmd/aergoluac/util/state_module.h b/cmd/aergoluac/luac/state_module.h similarity index 100% rename from cmd/aergoluac/util/state_module.h rename to cmd/aergoluac/luac/state_module.h diff --git a/cmd/aergoluac/main.go b/cmd/aergoluac/main.go index 1c4c120b3..a0ca3e818 100644 --- a/cmd/aergoluac/main.go +++ b/cmd/aergoluac/main.go @@ -10,6 +10,7 @@ import ( "os" "github.com/aergoio/aergo/v2/cmd/aergoluac/util" + "github.com/aergoio/aergo/v2/cmd/aergoluac/luac" "github.com/spf13/cobra" ) @@ -44,15 +45,15 @@ func init() { } } else if payload { if len(args) == 0 { - err = util.DumpFromStdin() + err = luac.DumpFromStdin() } else { - err = util.DumpFromFile(args[0]) + err = luac.DumpFromFile(args[0]) } } else { if len(args) < 2 { return errors.New("2 arguments required: ") } - err = util.CompileFromFile(args[0], args[1], abiFile) + err = luac.CompileFromFile(args[0], args[1], abiFile) } return err diff --git a/cmd/aergoluac/util/luac_util.go b/cmd/aergoluac/util/util.go similarity index 62% rename from cmd/aergoluac/util/luac_util.go rename to cmd/aergoluac/util/util.go index 9fc9daa89..4cef4a668 100644 --- a/cmd/aergoluac/util/luac_util.go +++ b/cmd/aergoluac/util/util.go @@ -1,137 +1,25 @@ package util -/* -#cgo CFLAGS: -I${SRCDIR}/../../../libtool/include/luajit-2.1 -#cgo LDFLAGS: ${SRCDIR}/../../../libtool/lib/libluajit-5.1.a -lm - -#include -#include -#include "compile.h" -*/ -import "C" import ( "bufio" "bytes" "encoding/binary" - "errors" "fmt" "io/ioutil" "os" - "runtime" - "unsafe" "github.com/aergoio/aergo/v2/internal/enc/hex" "github.com/aergoio/aergo/v2/internal/enc/base58" "github.com/aergoio/aergo/v2/cmd/aergoluac/encoding" ) -func NewLState() *C.lua_State { - L := C.luac_vm_newstate() - if L == nil { - runtime.GC() - L = C.luac_vm_newstate() - } - return L -} - -func CloseLState(L *C.lua_State) { - if L != nil { - C.luac_vm_close(L) - } -} - -func Compile(L *C.lua_State, code string) (LuaCode, error) { - cStr := C.CString(code) - defer C.free(unsafe.Pointer(cStr)) - if errMsg := C.vm_loadstring(L, cStr); errMsg != nil { - return nil, errors.New(C.GoString(errMsg)) - } - if errMsg := C.vm_stringdump(L); errMsg != nil { - return nil, errors.New(C.GoString(errMsg)) - } - return dumpToBytes(L), nil -} - -func CompileFromFile(srcFileName, outFileName, abiFileName string) error { - cSrcFileName := C.CString(srcFileName) - cOutFileName := C.CString(outFileName) - cAbiFileName := C.CString(abiFileName) - L := C.luac_vm_newstate() - defer C.free(unsafe.Pointer(cSrcFileName)) - defer C.free(unsafe.Pointer(cOutFileName)) - defer C.free(unsafe.Pointer(cAbiFileName)) - defer C.luac_vm_close(L) - - if errMsg := C.vm_compile(L, cSrcFileName, cOutFileName, cAbiFileName); errMsg != nil { - return errors.New(C.GoString(errMsg)) - } - return nil -} -func DumpFromFile(srcFileName string) error { - cSrcFileName := C.CString(srcFileName) - L := C.luac_vm_newstate() - defer C.free(unsafe.Pointer(cSrcFileName)) - defer C.luac_vm_close(L) - - if errMsg := C.vm_loadfile(L, cSrcFileName); errMsg != nil { - return errors.New(C.GoString(errMsg)) - } - if errMsg := C.vm_stringdump(L); errMsg != nil { - return errors.New(C.GoString(errMsg)) - } - - fmt.Println(encoding.EncodeCode(dumpToBytes(L))) - return nil -} - -func DumpFromStdin() error { - fi, err := os.Stdin.Stat() - if err != nil { - return err - } - var buf []byte - if (fi.Mode() & os.ModeCharDevice) == 0 { - buf, err = ioutil.ReadAll(os.Stdin) - if err != nil { - return err - } - } else { - var bBuf bytes.Buffer - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - bBuf.WriteString(scanner.Text() + "\n") - } - if err = scanner.Err(); err != nil { - return err - } - buf = bBuf.Bytes() - } - srcCode := C.CString(string(buf)) - L := C.luac_vm_newstate() - defer C.free(unsafe.Pointer(srcCode)) - defer C.luac_vm_close(L) - - if errMsg := C.vm_loadstring(L, srcCode); errMsg != nil { - return errors.New(C.GoString(errMsg)) - } - if errMsg := C.vm_stringdump(L); errMsg != nil { - return errors.New(C.GoString(errMsg)) - } - fmt.Println(encoding.EncodeCode(dumpToBytes(L))) - return nil -} - -func dumpToBytes(L *C.lua_State) LuaCode { - var ( - c, a *C.char - lc, la C.size_t - ) - c = C.lua_tolstring(L, -2, &lc) - a = C.lua_tolstring(L, -1, &la) - return NewLuaCode(C.GoBytes(unsafe.Pointer(c), C.int(lc)), C.GoBytes(unsafe.Pointer(a), C.int(la))) -} +//////////////////////////////////////////////////////////////////////////////// +// Decode +//////////////////////////////////////////////////////////////////////////////// +// Decode decodes the payload from a hex string or a base58 string or a JSON string +// and writes the bytecode, abi and deploy arguments to files func Decode(srcFileName string, payload string) error { var decoded []byte var err error @@ -223,6 +111,11 @@ func DecodeFromStdin() error { } +//////////////////////////////////////////////////////////////////////////////// +// LuaCode and LuaCodePayload +// used to store bytecode, abi and deploy arguments +//////////////////////////////////////////////////////////////////////////////// + type LuaCode []byte const byteCodeLenLen = 4 diff --git a/cmd/brick/context/context.go b/cmd/brick/context/context.go index 9f2bc330f..8119c1698 100644 --- a/cmd/brick/context/context.go +++ b/cmd/brick/context/context.go @@ -32,7 +32,7 @@ func Open(private bool) { err error ) if privateNet { - chain, err = vm_dummy.LoadDummyChain() + chain, err = vm_dummy.LoadDummyChain(vm_dummy.SetPrivNet()) } else { chain, err = vm_dummy.LoadDummyChain(vm_dummy.SetPubNet()) } diff --git a/config/config.go b/config/config.go index 8c23b2d0b..9ef451c34 100644 --- a/config/config.go +++ b/config/config.go @@ -120,21 +120,6 @@ func (ctx *ServerContext) GetDefaultPolarisConfig() *PolarisConfig { } } -func GetDefaultNumLStateClosers() int { - const ( - occupationFactor = 8 - minClosers = 1 - ) - if n := runtime.NumCPU() / occupationFactor; n >= minClosers { - return n - } - return minClosers -} - -func GetDefaultCloseLimit() int { - return 100 -} - func (ctx *ServerContext) GetDefaultBlockchainConfig() *BlockchainConfig { return &BlockchainConfig{ MaxBlockSize: types.DefaultMaxBlockSize, @@ -145,8 +130,6 @@ func (ctx *ServerContext) GetDefaultBlockchainConfig() *BlockchainConfig { ZeroFee: true, // deprecated StateTrace: 0, NumWorkers: runtime.NumCPU(), - NumLStateClosers: GetDefaultNumLStateClosers(), - CloseLimit: GetDefaultCloseLimit(), } } diff --git a/config/types.go b/config/types.go index f5e6944e6..d735df31c 100644 --- a/config/types.go +++ b/config/types.go @@ -116,8 +116,6 @@ type BlockchainConfig struct { StateTrace uint64 `mapstructure:"statetrace" description:"dump trace of setting state"` VerifyBlock uint64 `mapstructure:"verifyblock" description:"In verify only mode, server verifies given block of disk. server never modifies block chain'"` NumWorkers int `mapstructure:"numworkers" description:"maximum worker count for chainservice"` - NumLStateClosers int `mapstructure:"numclosers" description:"maximum LuaVM state closer count for chainservice"` - CloseLimit int `mapstructure:"closelimit" description:"number of LuaVM states which a LuaVM state closer closes at one time"` } // MempoolConfig defines configurations for mempool service @@ -251,8 +249,6 @@ maxanchorcount = "{{.Blockchain.MaxAnchorCount}}" verifiercount = "{{.Blockchain.VerifierCount}}" forceresetheight = "{{.Blockchain.ForceResetHeight}}" numworkers = "{{.Blockchain.NumWorkers}}" -numclosers = "{{.Blockchain.NumLStateClosers}}" -closelimit = "{{.Blockchain.CloseLimit}}" [mempool] showmetrics = {{.Mempool.ShowMetrics}} diff --git a/contract/db_module.c b/contract/db_module.c index e2ac603dc..a2c5c423a 100644 --- a/contract/db_module.c +++ b/contract/db_module.c @@ -1,99 +1,78 @@ -#include #include +#include #include #include #include -#include "vm.h" #include "sqlcheck.h" -#include "bignum_module.h" -#include "util.h" +#include "db_module.h" +#include "linkedlist.h" #include "_cgo_export.h" -#define LAST_ERROR(L,db,rc) \ - do { \ - if ((rc) != SQLITE_OK) { \ - luaL_error((L), sqlite3_errmsg((db))); \ - } \ - } while(0) - -#define RESOURCE_PSTMT_KEY "_RESOURCE_PSTMT_KEY_" -#define RESOURCE_RS_KEY "_RESOURCE_RS_KEY_" - -extern int getLuaExecContext(lua_State *L); -static void get_column_meta(lua_State *L, sqlite3_stmt* stmt); - -static int append_resource(lua_State *L, const char *key, void *data) { - int refno; - if (luaL_findtable(L, LUA_REGISTRYINDEX, key, 0) != NULL) { - luaL_error(L, "cannot find the environment of the db module"); - } - /* tab */ - lua_pushlightuserdata(L, data); /* tab pstmt */ - refno = luaL_ref(L, -2); /* tab */ - lua_pop(L, 1); /* remove tab */ - return refno; -} - -#define DB_PSTMT_ID "__db_pstmt__" - -typedef struct { +typedef struct stmt_t stmt_t; +struct stmt_t{ + stmt_t *next; + int id; sqlite3 *db; sqlite3_stmt *s; int closed; - int refno; -} db_pstmt_t; - -#define DB_RS_ID "__db_rs__" +}; -typedef struct { +typedef struct rs_t rs_t; +struct rs_t{ + rs_t *next; + int id; sqlite3 *db; sqlite3_stmt *s; int closed; int nc; int shared_stmt; char **decltypes; - int refno; -} db_rs_t; +}; -static db_rs_t *get_db_rs(lua_State *L, int pos) { - db_rs_t *rs = luaL_checkudata(L, pos, DB_RS_ID); - if (rs->closed) { - luaL_error(L, "resultset is closed"); +// list of stmt_t +stmt_t *pstmt_list = NULL; +// list of rs_t +rs_t *rs_list = NULL; + +int last_id = 0; + + +static void *malloc_zero(request *req, size_t size) { + void *ptr = malloc(size); + if (ptr == NULL) { + set_error(req, "out of memory"); + return NULL; } - return rs; + memset(ptr, 0, size); + return ptr; } -static int db_rs_tostr(lua_State *L) { - db_rs_t *rs = luaL_checkudata(L, 1, DB_RS_ID); - if (rs->closed) { - lua_pushfstring(L, "resultset is closed"); - } else { - lua_pushfstring(L, "resultset{handle=%p}", rs->s); +static int get_next_id() { + return ++last_id; +} + +static rs_t *get_rs(int id) { + rs_t *rs = rs_list; + while (rs != NULL && rs->id != id) { + rs = rs->next; } - return 1; + return rs; } static char *dup_decltype(const char *decltype) { - int n; - char *p; - char *c; - + char *p, *c; if (decltype == NULL) { return NULL; } - p = c = malloc(strlen(decltype)+1); - while ((*c++ = tolower(*decltype++))); - - if (strcmp(p, "date") == 0 || strcmp(p, "datetime") == 0 || strcmp(p, "timestamp") == 0 || - strcmp(p, "boolean") == 0) { - return p; + if (p == NULL) { + return NULL; } - free(p); - return NULL; + while ((*c++ = tolower(*decltype++))); + return p; } -static void free_decltypes(db_rs_t *rs) { +static void free_decltypes(rs_t *rs) { int i; for (i = 0; i < rs->nc; i++) { if (rs->decltypes[i] != NULL) { @@ -104,61 +83,68 @@ static void free_decltypes(db_rs_t *rs) { rs->decltypes = NULL; } -static int db_rs_get(lua_State *L) { - db_rs_t *rs = get_db_rs(L, 1); - int i; +void handle_rs_get(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + rs_t *rs; + int query_id, i; sqlite3_int64 d; double f; int n; const unsigned char *s; - if (rs->decltypes == NULL) { - luaL_error(L, "`get' called without calling `next'"); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; } + + query_id = get_int(&args, 1); + rs = get_rs(query_id); + if (rs == NULL || rs->decltypes == NULL) { + set_error(req, "'get' called without calling 'next'"); + return; + } + if (sqlite3_column_count(rs->s) != rs->nc) { + set_error(req, "column count mismatch - expected %d got %d", rs->nc, sqlite3_column_count(rs->s)); + return; + } + for (i = 0; i < rs->nc; i++) { switch (sqlite3_column_type(rs->s, i)) { case SQLITE_INTEGER: d = sqlite3_column_int64(rs->s, i); - if (rs->decltypes[i] == NULL) { - lua_pushinteger(L, d); - } else if (strcmp(rs->decltypes[i], "boolean") == 0) { - if (d != 0) { - lua_pushboolean(L, 1); - } else { - lua_pushboolean(L, 0); - } - } else { // date, datetime, timestamp + char *decltype = rs->decltypes[i]; + if (decltype && strcmp(decltype, "boolean") == 0) { + add_bool(&req->result, d != 0); + } else if (decltype && + (strcmp(decltype, "date") == 0 || + strcmp(decltype, "datetime") == 0 || + strcmp(decltype, "timestamp") == 0)) { char buf[80]; strftime(buf, 80, "%Y-%m-%d %H:%M:%S", gmtime((time_t *)&d)); - lua_pushlstring(L, (const char *)buf, strlen(buf)); + add_string(&req->result, buf); + } else { + add_int64(&req->result, d); } break; case SQLITE_FLOAT: f = sqlite3_column_double(rs->s, i); - lua_pushnumber(L, f); + add_double(&req->result, f); break; case SQLITE_TEXT: n = sqlite3_column_bytes(rs->s, i); s = sqlite3_column_text(rs->s, i); - lua_pushlstring(L, (const char *)s, n); + add_string(&req->result, s); break; case SQLITE_NULL: /* fallthrough */ default: /* unsupported types */ - lua_pushnil(L); + add_null(&req->result); } } - return rs->nc; -} - -static int db_rs_colcnt(lua_State *L) { - db_rs_t *rs = get_db_rs(L, 1); - lua_pushinteger(L, rs->nc); - return 1; } -static void db_rs_close(lua_State *L, db_rs_t *rs, int remove) { +static void rs_close(rs_t *rs, int remove) { if (rs->closed) { return; } @@ -170,136 +156,128 @@ static void db_rs_close(lua_State *L, db_rs_t *rs, int remove) { sqlite3_finalize(rs->s); } if (remove) { - if (luaL_findtable(L, LUA_REGISTRYINDEX, RESOURCE_RS_KEY, 0) != NULL) { - luaL_error(L, "cannot find the environment of the db module"); - } - luaL_unref(L, -1, rs->refno); - lua_pop(L, 1); + llist_remove(&rs_list, rs); } } -static int db_rs_next(lua_State *L) { - db_rs_t *rs = get_db_rs(L, 1); - int rc; +void handle_rs_next(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + int query_id, rc; + rs_t *rs; + + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; + } + + query_id = get_int(&args, 1); + rs = get_rs(query_id); + if (!rs) { + set_error(req, "invalid query id"); + return; + } rc = sqlite3_step(rs->s); - if (rc == SQLITE_DONE) { - db_rs_close(L, rs, 1); - lua_pushboolean(L, 0); - } else if (rc != SQLITE_ROW) { - rc = sqlite3_reset(rs->s); - LAST_ERROR(L, rs->db, rc); - db_rs_close(L, rs, 1); - lua_pushboolean(L, 0); + + if (rc == SQLITE_ROW) { + add_bool(&req->result, true); + } else if (rc == SQLITE_DONE) { + add_bool(&req->result, false); + rs_close(rs, 1); } else { - if (rs->decltypes == NULL) { - int i; - rs->decltypes = malloc(sizeof(char *) * rs->nc); - for (i = 0; i < rs->nc; i++) { - rs->decltypes[i] = dup_decltype(sqlite3_column_decltype(rs->s, i)); - } + rs_close(rs, 1); + if (rc != SQLITE_OK) { + set_error(req, sqlite3_errmsg(rs->db)); + } else { + set_error(req, "unknown error"); } - lua_pushboolean(L, 1); } - return 1; -} -static int db_rs_gc(lua_State *L) { - db_rs_close(L, luaL_checkudata(L, 1, DB_RS_ID), 1); - return 0; } -static db_pstmt_t *get_db_pstmt(lua_State *L, int pos) { - db_pstmt_t *pstmt = luaL_checkudata(L, pos, DB_PSTMT_ID); - if (pstmt->closed) { - luaL_error(L, "prepared statement is closed"); +static void process_columns(request *req, sqlite3_stmt *stmt, rs_t *rs) { + + int column_count = sqlite3_column_count(stmt); + + rs->nc = column_count; + rs->decltypes = malloc(sizeof(char *) * column_count); + if (rs->decltypes == NULL) { + set_error(req, "out of memory"); + return; } - return pstmt; + + for (int i = 0; i < column_count; i++) { + char *decltype = dup_decltype(sqlite3_column_decltype(stmt, i)); + rs->decltypes[i] = decltype; + } + + add_int(&req->result, column_count); + } -static int db_pstmt_tostr(lua_State *L) { - db_pstmt_t *pstmt = luaL_checkudata(L, 1, DB_PSTMT_ID); - if (pstmt->closed) { - lua_pushfstring(L, "prepared statement is closed"); - } else { - lua_pushfstring(L, "prepared statement{handle=%p}", pstmt->s); +static stmt_t *get_pstmt(int id) { + stmt_t *pstmt = pstmt_list; + while (pstmt != NULL && pstmt->id != id) { + pstmt = pstmt->next; } - return 1; + return pstmt; } -static int bind(lua_State *L, sqlite3 *db, sqlite3_stmt *pstmt) { +static int bind_parameters(request *req, sqlite3 *db, sqlite3_stmt *pstmt, bytes *params) { int rc, i; - int argc = lua_gettop(L) - 1; - int param_count; + int param_count = get_count(params); + int bind_count; - param_count = sqlite3_bind_parameter_count(pstmt); - if (argc != param_count) { - lua_pushfstring(L, "parameter count mismatch: want %d got %d", param_count, argc); + bind_count = sqlite3_bind_parameter_count(pstmt); + if (param_count != bind_count) { + set_error(req, "parameter count mismatch: want %d got %d", bind_count, param_count); return -1; } + if (param_count == 0) { + return 0; + } + rc = sqlite3_reset(pstmt); sqlite3_clear_bindings(pstmt); if (rc != SQLITE_ROW && rc != SQLITE_OK && rc != SQLITE_DONE) { - lua_pushfstring(L, sqlite3_errmsg(db)); + set_error(req, sqlite3_errmsg(db)); return -1; } - for (i = 1; i <= argc; i++) { - int t, b, n = i + 1; - const char *s; - size_t l; - - luaL_checkany(L, n); - t = lua_type(L, n); - - switch (t) { - case LUA_TNUMBER: - if (luaL_isinteger(L, n)) { - lua_Integer d = lua_tointeger(L, n); - rc = sqlite3_bind_int64(pstmt, i, (sqlite3_int64)d); - } else { - lua_Number d = lua_tonumber(L, n); - rc = sqlite3_bind_double(pstmt, i, (double)d); - } + i = 0; + char *ptr = NULL; + int len; + while (ptr = get_next_item(params, ptr, &len)) { + char type = get_type(ptr, len); + ptr += 1; len -= 1; + i++; + switch (type) { + case 'l': + rc = sqlite3_bind_int64(pstmt, i, read_int64(ptr)); + break; + case 'd': + rc = sqlite3_bind_double(pstmt, i, read_double(ptr)); break; - case LUA_TSTRING: - s = lua_tolstring(L, n, &l); - rc = sqlite3_bind_text(pstmt, i, s, l, SQLITE_TRANSIENT); + case 's': + rc = sqlite3_bind_text(pstmt, i, ptr, len-1, SQLITE_TRANSIENT); break; - case LUA_TBOOLEAN: - b = lua_toboolean(L, i+1); - if (b) { + case 'b': + if (read_bool(ptr)) { rc = sqlite3_bind_int(pstmt, i, 1); } else { rc = sqlite3_bind_int(pstmt, i, 0); } break; - case LUA_TNIL: + case 'n': rc = sqlite3_bind_null(pstmt, i); break; - case LUA_TUSERDATA: - { - if (lua_isbignumber(L, n)) { - long int d = lua_get_bignum_si(L, n); - if (d == 0 && lua_bignum_is_zero(L, n) != 0) { - char *s = lua_get_bignum_str(L, n); - if (s != NULL) { - lua_pushfstring(L, "bignum value overflow for binding %s", s); - free(s); - } - return -1; - } - rc = sqlite3_bind_int64(pstmt, i, (sqlite3_int64)d); - break; - } - } default: - lua_pushfstring(L, "unsupported type: %s", lua_typename(L, n)); + set_error(req, "unsupported type: %c", type); return -1; } if (rc != SQLITE_OK) { - lua_pushfstring(L, sqlite3_errmsg(db)); + set_error(req, sqlite3_errmsg(db)); return -1; } } @@ -307,342 +285,413 @@ static int bind(lua_State *L, sqlite3 *db, sqlite3_stmt *pstmt) { return 0; } -static int db_pstmt_exec(lua_State *L) { - int rc, n; - db_pstmt_t *pstmt = get_db_pstmt(L, 1); +void handle_stmt_exec(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + bytes params; + int pstmt_id, rc; + stmt_t *pstmt; + bool success; - /*check for exec in function */ - if (luaCheckView(getLuaExecContext(L)) > 0) { - luaL_error(L, "not permitted in view function"); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; } - rc = bind(L, pstmt->db, pstmt->s); + if (luaIsView(req->service)) { + set_error(req, "not permitted in view function"); + return; + } + + pstmt_id = get_int(&args, 1); + success = get_bytes(&args, 2, ¶ms); + if (!success) { + set_error(req, "invalid parameters"); //FIXME: remove later + return; + } + + pstmt = get_pstmt(pstmt_id); + if (pstmt == NULL) { + set_error(req, "invalid pstmt id"); + return; + } + + rc = bind_parameters(req, pstmt->db, pstmt->s, ¶ms); if (rc == -1) { sqlite3_reset(pstmt->s); sqlite3_clear_bindings(pstmt->s); - luaL_error(L, lua_tostring(L, -1)); + return; } + rc = sqlite3_step(pstmt->s); if (rc != SQLITE_ROW && rc != SQLITE_OK && rc != SQLITE_DONE) { + set_error(req, sqlite3_errmsg(pstmt->db)); sqlite3_reset(pstmt->s); sqlite3_clear_bindings(pstmt->s); - luaL_error(L, sqlite3_errmsg(pstmt->db)); + return; } - n = sqlite3_changes(pstmt->db); - lua_pushinteger(L, n); - return 1; + + add_int64(&req->result, sqlite3_changes(pstmt->db)); } -static int db_pstmt_query(lua_State *L) { - int rc; - db_pstmt_t *pstmt = get_db_pstmt(L, 1); - db_rs_t *rs; +void handle_stmt_query(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + bytes params; + int pstmt_id, rc; + stmt_t *pstmt; + rs_t *rs; + bool success; - getLuaExecContext(L); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; + } + + pstmt_id = get_int(&args, 1); + success = get_bytes(&args, 2, ¶ms); + if (!success) { + set_error(req, "invalid parameters"); //FIXME: remove later + return; + } + + pstmt = get_pstmt(pstmt_id); + if (pstmt == NULL) { + set_error(req, "invalid pstmt id"); + return; + } if (!sqlite3_stmt_readonly(pstmt->s)) { - luaL_error(L, "invalid sql command(permitted readonly)"); + set_error(req, "invalid sql command (only read permitted)"); + return; } - rc = bind(L, pstmt->db, pstmt->s); + + rc = bind_parameters(req, pstmt->db, pstmt->s, ¶ms); if (rc != 0) { sqlite3_reset(pstmt->s); sqlite3_clear_bindings(pstmt->s); - luaL_error(L, lua_tostring(L, -1)); + return; } - rs = (db_rs_t *) lua_newuserdata(L, sizeof(db_rs_t)); - luaL_getmetatable(L, DB_RS_ID); - lua_setmetatable(L, -2); + rs = (rs_t *) malloc_zero(req, sizeof(rs_t)); + if (rs == NULL) { + return; + } + rs->id = get_next_id(); rs->db = pstmt->db; rs->s = pstmt->s; rs->closed = 0; - rs->nc = sqlite3_column_count(pstmt->s); rs->shared_stmt = 1; - rs->decltypes = NULL; - rs->refno = append_resource(L, RESOURCE_RS_KEY, (void *)rs); + llist_add(&rs_list, rs); + + add_int(&req->result, rs->id); + + process_columns(req, pstmt->s, rs); - return 1; } -static void get_column_meta(lua_State *L, sqlite3_stmt* stmt) { +static void get_column_meta(request *req, sqlite3_stmt* stmt) { + buffer names = {0}; + buffer types = {0}; const char *name, *decltype; - int type; int colcnt = sqlite3_column_count(stmt); int i; - lua_createtable(L, 0, 2); - lua_pushinteger(L, colcnt); - lua_setfield(L, -2, "colcnt"); - if (colcnt > 0) { - lua_createtable(L, colcnt, 0); /* colinfos names */ - lua_createtable(L, colcnt, 0); /* colinfos names decltypes */ - } else { - lua_pushnil(L); - lua_pushnil(L); - } - for (i = 0; i < colcnt; i++) { name = sqlite3_column_name(stmt, i); if (name == NULL) { - lua_pushstring(L, ""); + add_string(&names, ""); } else { - lua_pushstring(L, name); + add_string(&names, name); } - lua_rawseti(L, -3, i+1); decltype = sqlite3_column_decltype(stmt, i); if (decltype == NULL) { - lua_pushstring(L, ""); + add_string(&types, ""); } else { - lua_pushstring(L, decltype); + add_string(&types, decltype); } - lua_rawseti(L, -2, i+1); } - lua_setfield(L, -3, "decltypes"); - lua_setfield(L, -2, "names"); + add_bytes(&req->result, names.ptr, names.len); + add_bytes(&req->result, types.ptr, types.len); + free(names.ptr); + free(types.ptr); } -static int db_pstmt_column_info(lua_State *L) { - int colcnt; - db_pstmt_t *pstmt = get_db_pstmt(L, 1); - getLuaExecContext(L); - - get_column_meta(L, pstmt->s); - return 1; -} +void handle_stmt_column_info(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + int pstmt_id; + stmt_t *pstmt; -static int db_pstmt_bind_param_cnt(lua_State *L) { - db_pstmt_t *pstmt = get_db_pstmt(L, 1); - getLuaExecContext(L); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; + } - lua_pushinteger(L, sqlite3_bind_parameter_count(pstmt->s)); + pstmt_id = get_int(&args, 1); + pstmt = get_pstmt(pstmt_id); + if (pstmt == NULL) { + set_error(req, "invalid pstmt id"); + return; + } - return 1; + get_column_meta(req, pstmt->s); } -static void db_pstmt_close(lua_State *L, db_pstmt_t *pstmt, int remove) { +static void stmt_close(stmt_t *pstmt, int remove) { if (pstmt->closed) { return; } pstmt->closed = 1; sqlite3_finalize(pstmt->s); if (remove) { - if (luaL_findtable(L, LUA_REGISTRYINDEX, RESOURCE_PSTMT_KEY, 0) != NULL) { - luaL_error(L, "cannot find the environment of the db module"); - } - luaL_unref(L, -1, pstmt->refno); - lua_pop(L, 1); + llist_remove(&pstmt_list, pstmt); } } -static int db_pstmt_gc(lua_State *L) { - db_pstmt_close(L, luaL_checkudata(L, 1, DB_PSTMT_ID), 1); - return 0; -} - -static int db_exec(lua_State *L) { - const char *cmd; +void handle_db_exec(request *req, char *args_ptr, int args_len) { sqlite3 *db; sqlite3_stmt *s; + bytes args = {args_ptr, args_len}; + bytes params; + char *sql; int rc; + bool success; - /*check for exec in function */ - if (luaCheckView(getLuaExecContext(L))> 0) { - luaL_error(L, "not permitted in view function"); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; } - cmd = luaL_checkstring(L, 1); - if (!sqlcheck_is_permitted_sql(cmd)) { - lua_pushfstring(L, "invalid sql commond:" LUA_QS, cmd); - lua_error(L); + if (luaIsView(req->service)) { + set_error(req, "not permitted in view function"); + return; + } + + sql = get_string(&args, 1); + success = get_bytes(&args, 2, ¶ms); + if (!success) { + set_error(req, "invalid parameters"); //FIXME: remove later + return; } - db = vm_get_db(L); - rc = sqlite3_prepare_v2(db, cmd, -1, &s, NULL); - LAST_ERROR(L, db, rc); - rc = bind(L, db, s); + if (!sqlcheck_is_permitted_sql(sql)) { + set_error(req, "invalid sql command: %s", sql); + return; + } + + db = vm_get_db(req); + if (db == NULL) { + // error already set by vm_get_db + return; + } + + rc = sqlite3_prepare_v2(db, sql, -1, &s, NULL); + if (rc != SQLITE_OK) { + set_error(req, sqlite3_errmsg(db)); + return; + } + + rc = bind_parameters(req, db, s, ¶ms); if (rc == -1) { sqlite3_finalize(s); - luaL_error(L, lua_tostring(L, -1)); + return; } rc = sqlite3_step(s); if (rc != SQLITE_ROW && rc != SQLITE_OK && rc != SQLITE_DONE) { + set_error(req, sqlite3_errmsg(db)); sqlite3_finalize(s); - luaL_error(L, sqlite3_errmsg(db)); + return; } sqlite3_finalize(s); - lua_pushinteger(L, sqlite3_changes(db)); - return 1; + add_int64(&req->result, sqlite3_changes(db)); + } -static int db_query(lua_State *L) { - const char *query; +void handle_db_query(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + bytes params; + char *sql; int rc; sqlite3 *db; sqlite3_stmt *s; - db_rs_t *rs; + rs_t *rs; + bool success; - getLuaExecContext(L); - query = luaL_checkstring(L, 1); - if (!sqlcheck_is_readonly_sql(query)) { - luaL_error(L, "invalid sql command(permitted readonly)"); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; + } + sql = get_string(&args, 1); + success = get_bytes(&args, 2, ¶ms); + if (!success) { + set_error(req, "invalid parameters"); //FIXME: remove later + return; + } + if (!sqlcheck_is_readonly_sql(sql)) { + set_error(req, "invalid sql command (only read permitted)"); + return; } - db = vm_get_db(L); - rc = sqlite3_prepare_v2(db, query, -1, &s, NULL); - LAST_ERROR(L, db, rc); - rc = bind(L, db, s); + db = vm_get_db(req); + if (db == NULL) { + // error already set by vm_get_db + return; + } + + rc = sqlite3_prepare_v2(db, sql, -1, &s, NULL); + if (rc != SQLITE_OK) { + set_error(req, sqlite3_errmsg(db)); + return; + } + + rc = bind_parameters(req, db, s, ¶ms); if (rc == -1) { sqlite3_finalize(s); - luaL_error(L, lua_tostring(L, -1)); + return; } - rs = (db_rs_t *) lua_newuserdata(L, sizeof(db_rs_t)); - luaL_getmetatable(L, DB_RS_ID); - lua_setmetatable(L, -2); + rs = (rs_t *) malloc_zero(req, sizeof(rs_t)); + if (rs == NULL) { + sqlite3_finalize(s); + set_error(req, "out of memory"); + return; + } + rs->id = get_next_id(); rs->db = db; rs->s = s; rs->closed = 0; - rs->nc = sqlite3_column_count(s); rs->shared_stmt = 0; - rs->decltypes = NULL; - rs->refno = append_resource(L, RESOURCE_RS_KEY, (void *)rs); + llist_add(&rs_list, rs); + + add_int(&req->result, rs->id); + + process_columns(req, s, rs); - return 1; } -static int db_prepare(lua_State *L) { - const char *sql; +void handle_db_prepare(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + char *sql; int rc; - int ref; sqlite3 *db; sqlite3_stmt *s; - db_pstmt_t *pstmt; + stmt_t *pstmt; - sql = luaL_checkstring(L, 1); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; + } + sql = get_string(&args, 1); if (!sqlcheck_is_permitted_sql(sql)) { - lua_pushfstring(L, "invalid sql commond:" LUA_QS, sql); - lua_error(L); + set_error(req, "invalid sql command: %s", sql); + return; } - db = vm_get_db(L); + + db = vm_get_db(req); + if (db == NULL) { + // error already set by vm_get_db + return; + } + rc = sqlite3_prepare_v2(db, sql, -1, &s, NULL); - LAST_ERROR(L, db, rc); + if (rc != SQLITE_OK) { + set_error(req, sqlite3_errmsg(db)); + return; + } - pstmt = (db_pstmt_t *) lua_newuserdata(L, sizeof(db_pstmt_t)); - luaL_getmetatable(L, DB_PSTMT_ID); - lua_setmetatable(L, -2); + pstmt = (stmt_t *) malloc_zero(req, sizeof(stmt_t)); + if (pstmt == NULL) { + sqlite3_finalize(s); + set_error(req, "out of memory"); + return; + } + pstmt->id = get_next_id(); pstmt->db = db; pstmt->s = s; pstmt->closed = 0; - pstmt->refno = append_resource(L, RESOURCE_PSTMT_KEY, (void *)pstmt); + llist_add(&pstmt_list, pstmt); - return 1; -} + add_int(&req->result, pstmt->id); + add_int(&req->result, sqlite3_bind_parameter_count(pstmt->s)); -static int db_get_snapshot(lua_State *L) { - char *snapshot; - int service = getLuaExecContext(L); +} - snapshot = LuaGetDbSnapshot(service); - strPushAndRelease(L, snapshot); +sqlite3 *vm_get_db(request *req) { + sqlite3 *db; + db = luaGetDbHandle(req->service); + if (db == NULL) { + set_error(req, "can't open a connection to the contract's database"); + } + return db; +} - return 1; +void handle_db_get_snapshot(request *req) { + char *snapshot; + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; + } + snapshot = LuaGetDbSnapshot(req->service); + add_string(&req->result, snapshot); } -static int db_open_with_snapshot(lua_State *L) { - char *snapshot = (char *) luaL_checkstring(L, 1); +void handle_db_open_with_snapshot(request *req, char *args_ptr, int args_len) { + bytes args = {args_ptr, args_len}; + char *snapshot; char *errStr; - int service = getLuaExecContext(L); - errStr = LuaGetDbHandleSnap(service, snapshot); - if (errStr != NULL) { - strPushAndRelease(L, errStr); - luaL_throwerror(L); + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; } - return 1; -} -static int db_last_insert_rowid(lua_State *L) { - sqlite3 *db; - sqlite3_int64 id; - db = vm_get_db(L); + snapshot = get_string(&args, 1); + if (snapshot == NULL) { + set_error(req, "invalid snapshot"); + return; + } - id = sqlite3_last_insert_rowid(db); - lua_pushinteger(L, id); - return 1; + errStr = LuaGetDbHandleSnap(req->service, snapshot); + if (errStr != NULL) { + set_error(req, errStr); + free(errStr); + return; + } + + add_string(&req->result, "ok"); } -int lua_db_release_resource(lua_State *L) { - lua_getfield(L, LUA_REGISTRYINDEX, RESOURCE_RS_KEY); - if (lua_istable(L, -1)) { - /* T */ - lua_pushnil(L); /* T nil(key) */ - while (lua_next(L, -2)) { - if (lua_islightuserdata(L, -1)) { - db_rs_close(L, (db_rs_t *) lua_topointer(L, -1), 0); - } - lua_pop(L, 1); - } - lua_pop(L, 1); - } - lua_getfield(L, LUA_REGISTRYINDEX, RESOURCE_PSTMT_KEY); - if (lua_istable(L, -1)) { - /* T */ - lua_pushnil(L); /* T nil(key) */ - while (lua_next(L, -2)) { - if (lua_islightuserdata(L, -1)) { - db_pstmt_close(L, (db_pstmt_t *) lua_topointer(L, -1), 0); - } - lua_pop(L, 1); - } - lua_pop(L, 1); +void handle_last_insert_rowid(request *req) { + if (!checkDbExecContext(req->service)) { + set_error(req, "invalid db context"); + return; } - return 0; + sqlite3 *db = vm_get_db(req); + if (db == NULL) { + // error already set by vm_get_db + return; + } + sqlite3_int64 id = sqlite3_last_insert_rowid(db); + add_int64(&req->result, id); } -static const luaL_Reg rs_methods[] = { - {"next", db_rs_next}, - {"get", db_rs_get}, - {"colcnt", db_rs_colcnt}, - {"__tostring", db_rs_tostr}, - {"__gc", db_rs_gc}, - {NULL, NULL} -}; - -static const luaL_Reg pstmt_methods[] = { - {"exec", db_pstmt_exec}, - {"query", db_pstmt_query}, - {"column_info", db_pstmt_column_info}, - {"bind_param_cnt", db_pstmt_bind_param_cnt}, - {"__tostring", db_pstmt_tostr}, - {"__gc", db_pstmt_gc}, - {NULL, NULL} -}; - -static const luaL_Reg db_lib[] = { - {"exec", db_exec}, - {"query", db_query}, - {"prepare", db_prepare}, - {"getsnap", db_get_snapshot}, - {"open_with_snapshot", db_open_with_snapshot}, - {"last_insert_rowid", db_last_insert_rowid}, - {NULL, NULL} -}; - -int luaopen_db(lua_State *L) { +void db_release_resource() { - luaL_newmetatable(L, DB_RS_ID); - lua_pushvalue(L, -1); - lua_setfield(L, -2, "__index"); - luaL_register(L, NULL, rs_methods); - - luaL_newmetatable(L, DB_PSTMT_ID); - lua_pushvalue(L, -1); - lua_setfield(L, -2, "__index"); - luaL_register(L, NULL, pstmt_methods); + rs_t *rs = rs_list, *rs_next; + while (rs != NULL) { + rs_next = rs->next; + rs_close(rs, 0); + free(rs); + rs = rs_next; + } + rs_list = NULL; - luaL_register(L, "db", db_lib); + stmt_t *pstmt = pstmt_list, *pstmt_next; + while (pstmt != NULL) { + pstmt_next = pstmt->next; + stmt_close(pstmt, 0); + free(pstmt); + pstmt = pstmt_next; + } + pstmt_list = NULL; - lua_pop(L, 3); - return 1; } diff --git a/contract/db_module.h b/contract/db_module.h index c8b425719..534ee09c3 100644 --- a/contract/db_module.h +++ b/contract/db_module.h @@ -1,9 +1,26 @@ #ifndef _DB_MODULE_H #define _DB_MODULE_H -#include "lua.h" +#include +#include -extern int luaopen_db(lua_State *L); -extern int lua_db_release_resource(lua_State *L); +#include "sqlite3-binding.h" +#include "db_msg.h" + +sqlite3 *vm_get_db(request *req); + +void handle_db_exec(request *req, char *args_ptr, int args_len); +void handle_db_query(request *req, char *args_ptr, int args_len); +void handle_db_prepare(request *req, char *args_ptr, int args_len); +void handle_db_get_snapshot(request *req); +void handle_db_open_with_snapshot(request *req, char *args_ptr, int args_len); +void handle_last_insert_rowid(request *req); +void handle_stmt_exec(request *req, char *args_ptr, int args_len); +void handle_stmt_query(request *req, char *args_ptr, int args_len); +void handle_stmt_column_info(request *req, char *args_ptr, int args_len); +void handle_rs_get(request *req, char *args_ptr, int args_len); +void handle_rs_next(request *req, char *args_ptr, int args_len); + +void db_release_resource(); #endif /* _DB_MODULE_H */ diff --git a/contract/db_msg.c b/contract/db_msg.c new file mode 100644 index 000000000..fa21bbb28 --- /dev/null +++ b/contract/db_msg.c @@ -0,0 +1,425 @@ +#include +#include +#include +#include +#include +#include +#include "db_msg.h" + +void set_error(request *req, const char *format, ...) { + + if (req == NULL) { + return; // avoid null pointer dereference + } + + va_list args; + va_start(args, format); + + // determine the required buffer size + va_list args_copy; + va_copy(args_copy, args); + int size = vsnprintf(NULL, 0, format, args_copy) + 1; // +1 for null terminator + va_end(args_copy); + + if (size <= 0) { + va_end(args); + return; // error in formatting + } + + // allocate memory for the new error message + char *new_error = malloc(size); + if (new_error == NULL) { + va_end(args); + return; // memory allocation failed + } + + // format the error message + vsnprintf(new_error, size, format, args); + va_end(args); + + // free the old error message if it exists + if (req->error != NULL) { + free(req->error); + } + // set the new error message + req->error = new_error; +} + +// serialization + +// copy int32 to buffer, stored as little endian, for unaligned access +void write_int(char *pdest, int value) { + unsigned char *source = (unsigned char *) &value; + unsigned char *dest = (unsigned char *) pdest; + dest[0] = source[0]; + dest[1] = source[1]; + dest[2] = source[2]; + dest[3] = source[3]; +} + +// read_int32, stored as little endian, for unaligned access +int read_int(char *p) { + int value; + unsigned char *source = (unsigned char *) p; + unsigned char *dest = (unsigned char *) &value; + dest[0] = source[0]; + dest[1] = source[1]; + dest[2] = source[2]; + dest[3] = source[3]; + return value; +} + +// read_int64, stored as little endian, for unaligned access +int64_t read_int64(char *p) { + int64_t value; + unsigned char *source = (unsigned char *) p; + unsigned char *dest = (unsigned char *) &value; + dest[0] = source[0]; + dest[1] = source[1]; + dest[2] = source[2]; + dest[3] = source[3]; + dest[4] = source[4]; + dest[5] = source[5]; + dest[6] = source[6]; + dest[7] = source[7]; + return value; +} + +double read_double(char *p) { + double value; + unsigned char *source = (unsigned char *) p; + unsigned char *dest = (unsigned char *) &value; + dest[0] = source[0]; + dest[1] = source[1]; + dest[2] = source[2]; + dest[3] = source[3]; + dest[4] = source[4]; + dest[5] = source[5]; + dest[6] = source[6]; + dest[7] = source[7]; + return value; +} + +//////////////////////////////////////////////////////////////////////////////// +// add item + +// add item with 4 bytes length +void add_item(buffer *buf, const char *data, int len) { + int item_size = 4 + len; + if (item_size > buf->allocated) { + // compute new size + int new_size = buf->allocated; + if (new_size == 0) { + new_size = 1024; + } + while (new_size < buf->len + item_size) { + new_size *= 2; + } + // reallocate buffer + buf->allocated = new_size; + buf->ptr = (char *)realloc(buf->ptr, buf->allocated); + if (buf->ptr == NULL) { + // TODO: error handling + } + } + // store the length of the item + //*(int *)(req->result + req->used_size) = len; + write_int(buf->ptr + buf->len, len); + // copy item to buffer + memcpy(buf->ptr + buf->len + 4, data, len); + buf->len += item_size; +} + +// now adding an additional byte for type +void add_typed_item(buffer *buf, char type, const char *data, int len) { + int item_size = 4 + 1 + len; + if (item_size > buf->allocated) { + // compute new size + int new_size = buf->allocated; + if (new_size == 0) { + new_size = 1024; + } + while (new_size < buf->len + item_size) { + new_size *= 2; + } + // reallocate buffer + buf->allocated = new_size; + buf->ptr = (char *)realloc(buf->ptr, buf->allocated); + if (buf->ptr == NULL) { + // TODO: error handling + } + } + // store the length of the item + write_int(buf->ptr + buf->len, len + 1); + // store the type of the item + buf->ptr[buf->len + 4] = type; + // copy item to buffer + memcpy(buf->ptr + buf->len + 5, data, len); + buf->len += item_size; +} + +// add items with type + +void add_string(buffer *buf, const char *str) { + if (str == NULL) str = ""; + add_typed_item(buf, 's', str, strlen(str) + 1); +} + +void add_string_ex(buffer *buf, const char *str, int len) { + if (str == NULL) str = ""; + add_typed_item(buf, 's', str, len + 1); +} + +void add_int(buffer *buf, int value) { + add_typed_item(buf, 'i', (char *)&value, 4); +} + +void add_int64(buffer *buf, int64_t value) { + add_typed_item(buf, 'l', (char *)&value, 8); +} + +void add_double(buffer *buf, double value) { + add_typed_item(buf, 'd', (char *)&value, 8); +} + +void add_bool(buffer *buf, bool value) { + add_typed_item(buf, 'b', (char *)&value, 1); +} + +void add_bytes(buffer *buf, const char *data, int len) { + add_typed_item(buf, 'y', data, len); +} + +void add_null(buffer *buf) { + add_typed_item(buf, 'n', NULL, 0); +} + +//////////////////////////////////////////////////////////////////////////////// +// read item + +// get item at position +char *get_item(bytes *data, int position, int *plen) { + char *p = data->ptr; + char *plimit = data->ptr + data->len; + int len; + int count = 1; + + if (p == NULL || position <= 0) { + return NULL; + } + + while (count < position) { + if (plimit - p < 4) { + return NULL; + } + len = read_int(p); + p += 4; + p += len; + count++; + } + + if (plimit - p < 4) { + return NULL; + } + len = read_int(p); + p += 4; + if (p + len > plimit) { + return NULL; + } + if (plen != NULL) *plen = len; + return p; +} + +int get_count(bytes *data) { + int count = 0; + int len; + char *p = data->ptr; + while (p < data->ptr + data->len) { + len = read_int(p); + p += 4; + p += len; + count++; + } + return count; +} + +// get string at position +char *get_string(bytes *data, int position) { + int len; + char *p = get_item(data, position, &len); + if (p == NULL || len < 2 || *p != 's') { + return NULL; + } + // skip type + p++; + return p; +} + +// get int at position +int get_int(bytes *data, int position) { + int len; + char *p = get_item(data, position, &len); + if (p == NULL || len != 1+4) { + return 0; + } + // check type + if (*p != 'i') { + return 0; + } + // skip type + p++; + return read_int(p); +} + +// get int64 at position +int64_t get_int64(bytes *data, int position) { + int64_t value; + int len; + char *p = get_item(data, position, &len); + if (p == NULL || len != 1+8) { + return 0; + } + // check type + if (*p != 'l') { + return 0; + } + // skip type + p++; + memcpy(&value, p, 8); + return value; +} + +// get double at position +double get_double(bytes *data, int position) { + double value; + int len; + char *p = get_item(data, position, &len); + if (p == NULL || len != 1+8) { + return 0; + } + // check type + if (*p != 'd') { + return 0; + } + // skip type + p++; + memcpy(&value, p, 8); + return value; +} + +// get bool at position +bool get_bool(bytes *data, int position) { + int len; + char *p = get_item(data, position, &len); + if (p == NULL || len != 1+1) { + return false; + } + // check type + if (*p != 'b') { + return false; + } + // skip type + p++; + return *p; +} + +bool get_bytes(bytes *data, int position, bytes *pbytes) { + int len; + char *p = get_item(data, position, &len); + if (p == NULL || len < 1 || *p != 'y') { + return false; + } + p++; len--; + pbytes->ptr = p; + pbytes->len = len; + return true; +} + +//////////////////////////////////////////////////////////////////////////////// +// iterate over items + +// get the next item +char *get_next_item(bytes *data, char *pdata, int *plen) { + char *plimit = data->ptr + data->len; + int len; + + if (pdata == NULL) { + *plen = read_int(data->ptr); + return data->ptr + 4; + } + + if (pdata < data->ptr + 4 || pdata > plimit) { + return NULL; + } + + // skip this data + pdata += *plen; + // check if there is more data + if (plimit - pdata < 4) { + *plen = 0; + return NULL; + } + // get the length of the next item + len = read_int(pdata); + // skip the length + pdata += 4; + // check if there is more data + if (pdata + len > plimit) { + *plen = 0; + return NULL; + } + // return the length of the next item + *plen = len; + // return the pointer to the next item + return pdata; +} + +char get_type(char *ptr, int len) { + char type = *ptr; + switch (type) { + case 'i': + if (len != 1+4) { + return 0; + } + break; + case 'l': + if (len != 1+8) { + return 0; + } + break; + case 'd': + if (len != 1+8) { + return 0; + } + break; + case 'b': + if (len != 1+1) { + return 0; + } + } + return type; +} + +bool read_bool(char *p) { + return *p; +} + +//////////////////////////////////////////////////////////////////////////////// + +void free_buffer(buffer *buf) { + if (buf->ptr != NULL) { + free(buf->ptr); + buf->ptr = NULL; + } +} + +void free_response(rresponse *resp) { + if(resp->result.ptr != NULL) { + free(resp->result.ptr); + resp->result.ptr = NULL; + } + if (resp->error != NULL) { + free(resp->error); + resp->error = NULL; + } +} diff --git a/contract/db_msg.h b/contract/db_msg.h new file mode 100644 index 000000000..c4106e55e --- /dev/null +++ b/contract/db_msg.h @@ -0,0 +1,62 @@ +#ifndef DB_MSG_H +#define DB_MSG_H + +typedef struct { + char *ptr; + int len; + int allocated; +} buffer; + +typedef struct { + char *ptr; + int len; +} bytes; + +typedef struct { + int service; + buffer result; + char *error; +} request; + +typedef struct { + bytes result; + char *error; +} rresponse; + + +void set_error(request *req, const char *format, ...); +void write_int(char *pdest, int value); +int read_int(char *p); +int64_t read_int64(char *p); +double read_double(char *p); + +void add_item(buffer *buf, const char *data, int len); +void add_typed_item(buffer *buf, char type, const char *data, int len); +void add_string(buffer *buf, const char *str); +void add_string_ex(buffer *buf, const char *str, int len); +void add_int(buffer *buf, int value); +void add_int64(buffer *buf, int64_t value); +void add_double(buffer *buf, double value); +void add_bool(buffer *buf, bool value); +void add_bytes(buffer *buf, const char *data, int len); +void add_null(buffer *buf); + +char *get_item(bytes *data, int position, int *plen); +int get_count(bytes *data); +char *get_string(bytes *data, int position); +char *get_string_ex(bytes *data, int position, int *plen); +int get_int(bytes *data, int position); +int64_t get_int64(bytes *data, int position); +double get_double(bytes *data, int position); +bool get_bool(bytes *data, int position); +bool get_bytes(bytes *data, int position, bytes *pbytes); +char *get_next_item(bytes *data, char *pdata, int *plen); +char get_type(char *ptr, int len); +double read_double(char *p); +bool read_bool(char *p); + +void free_buffer(buffer *buf); +void free_response(rresponse *resp); + + +#endif // DB_MSG_H diff --git a/contract/db_msg_test.c.disabled b/contract/db_msg_test.c.disabled new file mode 100644 index 000000000..a5c050045 --- /dev/null +++ b/contract/db_msg_test.c.disabled @@ -0,0 +1,192 @@ +#include +#include +#include +#include +#include +#include +#include "db_msg.c" + +#define TEST(name) void test_##name(void) +#define RUN_TEST(name) do { \ + printf("Running %s...", #name); \ + test_##name(); \ + printf(" PASSED\n"); \ +} while(0) + +#define ASSERT(condition) do { \ + if (!(condition)) { \ + printf(" FAILED\n"); \ + printf("Assertion failed: %s\n", #condition); \ + exit(1); \ + } \ +} while(0) + +#define ASSERT_EQUAL(expected, actual) ASSERT((expected) == (actual)) +#define ASSERT_STRING_EQUAL(expected, actual) ASSERT(strcmp((expected), (actual)) == 0) +#define ASSERT_DOUBLE_EQUAL(expected, actual, epsilon) ASSERT(fabs((expected) - (actual)) < (epsilon)) + +TEST(set_error) { + request req = {0}; + set_error(&req, "Error %d", 42); + ASSERT_STRING_EQUAL("Error 42", req.error); +} + +TEST(write_and_read_int) { + char buffer[4]; + int value = 12345; + write_int(buffer, value); + int result = read_int(buffer); + ASSERT_EQUAL(value, result); +} + +TEST(read_int64) { + char buffer[8] = {0x78, 0x56, 0x34, 0x12, 0xF0, 0xDE, 0xBC, 0x9A}; + int64_t result = read_int64(buffer); + ASSERT_EQUAL(0x9ABCDEF012345678, result); +} + +TEST(read_double) { + char buffer[8] = {0x18, 0x2D, 0x44, 0x54, 0xFB, 0x21, 0x09, 0x40}; + double result = read_double(buffer); + ASSERT_DOUBLE_EQUAL(3.141592653589793, result, 1e-15); +} + +TEST(add_and_get_string) { + buffer buf = {0}; + add_string(&buf, "Hello, World!"); + + bytes data = {buf.ptr, buf.len}; + char* result = get_string(&data, 1); + ASSERT_STRING_EQUAL("Hello, World!", result); + + free_buffer(&buf); +} + +TEST(add_and_get_int) { + buffer buf = {0}; + add_int(&buf, 42); + + bytes data = {buf.ptr, buf.len}; + int result = get_int(&data, 1); + ASSERT_EQUAL(42, result); + + free_buffer(&buf); +} + +TEST(add_and_get_int64) { + buffer buf = {0}; + add_int64(&buf, 1234567890123456789LL); + + bytes data = {buf.ptr, buf.len}; + int64_t result = get_int64(&data, 1); + ASSERT_EQUAL(1234567890123456789LL, result); + + free_buffer(&buf); +} + +TEST(add_and_get_double) { + buffer buf = {0}; + add_double(&buf, 3.14159); + + bytes data = {buf.ptr, buf.len}; + double result = get_double(&data, 1); + ASSERT_DOUBLE_EQUAL(3.14159, result, 1e-5); + + free_buffer(&buf); +} + +TEST(add_and_get_bool) { + buffer buf = {0}; + add_bool(&buf, true); + add_bool(&buf, false); + + bytes data = {buf.ptr, buf.len}; + bool result1 = get_bool(&data, 1); + bool result2 = get_bool(&data, 2); + ASSERT(result1); + ASSERT(!result2); + + free_buffer(&buf); +} + +TEST(add_and_get_bytes) { + buffer buf = {0}; + const char data[] = {0x01, 0x02, 0x03, 0x04}; + add_bytes(&buf, data, sizeof(data)); + + bytes data_bytes = {buf.ptr, buf.len}; + bytes result; + bool success = get_bytes(&data_bytes, 1, &result); + ASSERT(success); + ASSERT(result.len == sizeof(data)); + ASSERT(memcmp(result.ptr, data, sizeof(data)) == 0); + + free_buffer(&buf); +} + +TEST(get_count) { + buffer buf = {0}; + add_string(&buf, "One"); + add_int(&buf, 2); + add_double(&buf, 3.0); + + bytes data = {buf.ptr, buf.len}; + int count = get_count(&data); + ASSERT_EQUAL(3, count); + + free_buffer(&buf); +} + +TEST(get_next_item) { + buffer buf = {0}; + add_int(&buf, 123); + add_string(&buf, "Hello"); + add_double(&buf, 3.14159); + add_bool(&buf, true); + + bytes data = {buf.ptr, buf.len}; + int len; + char* item = get_next_item(&data, NULL, &len); + ASSERT_EQUAL('i', get_type(item,len)); + item++; len--; + ASSERT_EQUAL(4, len); + ASSERT_EQUAL(123, read_int(item)); + + item = get_next_item(&data, item, &len); + ASSERT_EQUAL('s', get_type(item,len)); + item++; len--; + ASSERT_STRING_EQUAL("Hello", item); + + item = get_next_item(&data, item, &len); + ASSERT_EQUAL('d', get_type(item,len)); + item++; len--; + ASSERT_DOUBLE_EQUAL(3.14159, read_double(item), 1e-5); + + item = get_next_item(&data, item, &len); + ASSERT_EQUAL('b', get_type(item,len)); + item++; len--; + ASSERT(read_bool(item)); + + item = get_next_item(&data, item, &len); + ASSERT(item == NULL); + + free_buffer(&buf); +} + +int main(void) { + RUN_TEST(set_error); + RUN_TEST(write_and_read_int); + RUN_TEST(read_int64); + RUN_TEST(read_double); + RUN_TEST(add_and_get_string); + RUN_TEST(add_and_get_int); + RUN_TEST(add_and_get_int64); + RUN_TEST(add_and_get_double); + RUN_TEST(add_and_get_bool); + RUN_TEST(add_and_get_bytes); + RUN_TEST(get_count); + RUN_TEST(get_next_item); + + printf("All tests passed!\n"); + return 0; +} diff --git a/contract/linkedlist.c b/contract/linkedlist.c new file mode 100644 index 000000000..697be5f3c --- /dev/null +++ b/contract/linkedlist.c @@ -0,0 +1,98 @@ +// correct use: +// llist_add(&first, item); +// llist_prepend(&first, item); +// llist_remove(&first, item); +// llist_count(first); +// llist_get(first, pos); + +#include +#include "linkedlist.h" + +typedef struct llitem llitem; +struct llitem { + llitem *next; +}; + +void llist_add(void *pfirst, void *pto_add) { + llitem **first, *to_add, *item; + + first = (llitem **) pfirst; + to_add = (llitem *) pto_add; + + item = *first; + if (item == 0) { + *first = to_add; + } else { + while (item->next != 0) { + item = item->next; + } + item->next = to_add; + } + +} + +void llist_prepend(void *pfirst, void *pto_add) { + llitem **first, *to_add, *item; + + first = (llitem **) pfirst; + to_add = (llitem *) pto_add; + + item = *first; + *first = to_add; + to_add->next = item; + +} + +/* safer version: other threads can be iterating the list while item is removed */ +/* caller should not release the memory immediately */ +void llist_safe_remove(void *pfirst, void *pto_del) { + llitem **first, *to_del, *item; + + first = (llitem **) pfirst; + to_del = (llitem *) pto_del; + + item = *first; + if (to_del == item) { + *first = to_del->next; + } else { + while (item != NULL) { + if (item->next == to_del) { + item->next = to_del->next; + break; + } + item = item->next; + } + } + +} + +void llist_remove(void *pfirst, void *pto_del) { + llitem *to_del = (llitem *) pto_del; + llist_safe_remove(pfirst, pto_del); + to_del->next = NULL; /* unsafe for concurrent threads without mutex */ +} + +int llist_count(void *list) { + llitem *item = (llitem *) list; + int count = 0; + + while (item) { + count++; + item = item->next; + } + + return count; +} + +void* llist_get(void *list, int pos) { + llitem *item = (llitem *) list; + int count = 0; + + while (item) { + if (count==pos) return item; + count++; + item = item->next; + } + + return NULL; +} diff --git a/contract/linkedlist.h b/contract/linkedlist.h new file mode 100644 index 000000000..a3ed9334e --- /dev/null +++ b/contract/linkedlist.h @@ -0,0 +1,10 @@ +#ifndef LINKEDLIST_H +#define LINKEDLIST_H + +void llist_add(void *pfirst, void *pto_add); +void llist_prepend(void *pfirst, void *pto_add); +void llist_remove(void *pfirst, void *pto_del); +int llist_count(void *list); +void* llist_get(void *list, int pos); + +#endif diff --git a/contract/lstate_factory.go b/contract/lstate_factory.go deleted file mode 100644 index a35f28091..000000000 --- a/contract/lstate_factory.go +++ /dev/null @@ -1,111 +0,0 @@ -package contract - -/* -#include -#include "bignum_module.h" -#include "vm.h" -*/ -import "C" -import ( - "sync" -) - -var maxLStates int -var getCh chan *LState -var freeCh chan *LState -var once sync.Once - -func StartLStateFactory(numLStates, numClosers, numCloseLimit int) { - once.Do(func() { - C.init_bignum() - C.initViewFunction() - - maxLStates = numLStates - getCh = make(chan *LState, numLStates) - freeCh = make(chan *LState, numLStates) - - for i := 0; i < numLStates; i++ { - getCh <- newLState() - } - - for i := 0; i < numClosers; i++ { - go statePool(numCloseLimit) - } - }) -} - -func statePool(numCloseLimit int) { - s := newLStatesBuffer(numCloseLimit) - - for { - select { - case state := <-freeCh: - s.append(state) - getCh <- newLState() - } - } -} - -func GetLState() *LState { - state := <-getCh - ctrLgr.Trace().Msg("LState acquired") - return state -} - -func FreeLState(state *LState) { - if state != nil { - freeCh <- state - ctrLgr.Trace().Msg("LState released") - } -} - -func FlushLStates() { - for i := 0; i < maxLStates; i++ { - s := GetLState() - FreeLState(s) - } -} - -//--------------------------------------------------------------------// -// LState type - -type LState = C.struct_lua_State - -func newLState() *LState { - ctrLgr.Trace().Msg("LState created") - return C.vm_newstate(C.int(currentForkVersion)) -} - -func (L *LState) close() { - if L != nil { - C.lua_close(L) - } -} - -type lStatesBuffer struct { - s []*LState - limit int -} - -func newLStatesBuffer(limit int) *lStatesBuffer { - return &lStatesBuffer{ - s: make([]*LState, 0), - limit: limit, - } -} - -func (Ls *lStatesBuffer) len() int { - return len(Ls.s) -} - -func (Ls *lStatesBuffer) append(s *LState) { - Ls.s = append(Ls.s, s) - if Ls.len() == Ls.limit { - Ls.close() - } -} - -func (Ls *lStatesBuffer) close() { - C.vm_closestates(&Ls.s[0], C.int(len(Ls.s))) - Ls.s = Ls.s[:0] -} diff --git a/contract/msg/vm_messages.go b/contract/msg/vm_messages.go new file mode 100644 index 000000000..be5e7d733 --- /dev/null +++ b/contract/msg/vm_messages.go @@ -0,0 +1,145 @@ +package msg + +import ( + "bytes" + "strings" + "encoding/binary" + "io" + "net" + "time" + "errors" + "fmt" +) + +// SerializeMessage serializes a variable number of strings into a byte slice +func SerializeMessage(strings ...string) []byte { + var buf bytes.Buffer + + // write number of strings + binary.Write(&buf, binary.LittleEndian, uint32(len(strings))) + + // write each string's length and content + for _, s := range strings { + length := uint32(len(s)) + binary.Write(&buf, binary.LittleEndian, length) + buf.WriteString(s) + } + + return buf.Bytes() +} + +func SerializeMessageBytes(args ...[]byte) []byte { + var buf bytes.Buffer + + // write number of strings + binary.Write(&buf, binary.LittleEndian, uint32(len(args))) + + // write each string's length and content + for _, b := range args { + length := uint32(len(b)) + binary.Write(&buf, binary.LittleEndian, length) + buf.Write(b) + } + + return buf.Bytes() +} + +// DeserializeMessage deserializes a byte slice into an array of strings +func DeserializeMessage(data []byte) ([]string, error) { + var strings []string + buf := bytes.NewReader(data) + + // read number of strings + var numStrings uint32 + if err := binary.Read(buf, binary.LittleEndian, &numStrings); err != nil { + return nil, err + } + + // read each string's length and content without making unnecessary copies, + // by creating a slice that references the original buffer + for i := uint32(0); i < numStrings; i++ { + var length uint32 + if err := binary.Read(buf, binary.LittleEndian, &length); err != nil { + return nil, err + } + + // get the current position + pos, err := buf.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + + // create a slice that references the original buffer + strBytes := data[pos : pos+int64(length)] + buf.Seek(int64(length), io.SeekCurrent) // move the buffer position forward + + strings = append(strings, string(strBytes)) + } + + return strings, nil +} + +func SendMessage(conn net.Conn, message []byte) (err error) { + + // send the length prefix + lengthBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(lengthBytes, uint32(len(message))) + _, err = conn.Write(lengthBytes) + if err != nil { + return err + } + + // send the message + _, err = conn.Write(message) + if err != nil { + return err + } + + // attempt to flush the connection + if flusher, ok := conn.(interface{ Flush() error }); ok { + err = flusher.Flush() + if err != nil { + return fmt.Errorf("error flushing connection: %w", err) + } + } + + return nil +} + +// waits for a full message (length prefix + data) from the abstract domain socket +func WaitForMessage(conn net.Conn, deadline time.Time) (msg []byte, err error) { + + if !deadline.IsZero() { + // define the deadline for the connection + conn.SetReadDeadline(deadline) + } + + // read the length prefix + length := make([]byte, 4) + n, err := io.ReadFull(conn, length) + if err != nil { + if err == io.EOF && n == 0 { + return nil, fmt.Errorf("connection closed before reading message length") + } + if strings.Contains(err.Error(), "i/o timeout") { + return nil, errors.New("contract timeout during vm execution") + } + return nil, fmt.Errorf("error reading message length (read %d bytes): %w", n, err) + } + + // use little endian to get the length + msgLength := binary.LittleEndian.Uint32(length) + + // read the message + msg = make([]byte, msgLength) + n, err = io.ReadFull(conn, msg) + if err != nil { + if strings.Contains(err.Error(), "i/o timeout") { + return nil, errors.New("contract timeout during vm execution") + } + return nil, fmt.Errorf("error reading message body (read %d/%d bytes): %w", n, msgLength, err) + } + + // return the message + return msg, nil +} diff --git a/contract/statesql.go b/contract/statesql.go index 82b54a124..bbcc8232c 100644 --- a/contract/statesql.go +++ b/contract/statesql.go @@ -2,6 +2,7 @@ package contract /* #include "sqlite3-binding.h" +#include "db_module.h" */ import "C" import ( @@ -13,6 +14,7 @@ import ( "os" "path/filepath" "sync" + "runtime" "github.com/aergoio/aergo-lib/log" "github.com/aergoio/aergo/v2/internal/enc/base58" @@ -110,6 +112,7 @@ func LoadTestDatabase(dataDir string) error { func CloseDatabase() { var err error + C.db_release_resource() for name, db := range database.DBs { if db.tx != nil { err = db.tx.rollback() @@ -122,8 +125,10 @@ func CloseDatabase() { if err != nil { sqlLgr.Warn().Err(err).Str("db_name", name).Msg("SQL DB close") } - delete(database.DBs, name) } + database.DBs = make(map[string]*litetree) + // now the routine can migrate to another thread + runtime.UnlockOSThread() } func SaveRecoveryPoint(bs *state.BlockState) error { @@ -137,21 +142,21 @@ func SaveRecoveryPoint(bs *state.BlockState) error { sqlLgr.Warn().Err(err).Str("db_name", id).Msg("SQL TX commit") continue } - rp := db.recoveryPoint() - if rp == 0 { + lastRP := db.getLastRecoveryPoint() + if lastRP == 0 { return ErrFindRp } - if rp > 0 { + if lastRP > 0 { if sqlLgr.IsDebugEnabled() { - sqlLgr.Debug().Str("db_name", id).Uint64("commit_id", rp).Msg("save recovery point") + sqlLgr.Debug().Str("db_name", id).Uint64("commit_id", lastRP).Msg("save recovery point") } - receiverState, err := bs.GetAccountState(db.accountID) + contractState, err := bs.GetAccountState(db.accountID) if err != nil { return err } - receiverChange := receiverState.Clone() - receiverChange.SqlRecoveryPoint = uint64(rp) - err = bs.PutState(db.accountID, receiverChange) + updatedContractState := contractState.Clone() + updatedContractState.SqlRecoveryPoint = lastRP + err = bs.PutState(db.accountID, updatedContractState) if err != nil { return err } @@ -162,7 +167,7 @@ func SaveRecoveryPoint(bs *state.BlockState) error { } func beginTx(dbName string, rp uint64) (sqlTx, error) { - db, err := conn(dbName) + db, err := getDBConnection(dbName) defer func() { if err != nil { delete(database.DBs, dbName) @@ -202,7 +207,7 @@ func beginReadOnly(dbName string, rp uint64) (sqlTx, error) { return newReadOnlySqlTx(db, rp) } -func conn(dbName string) (*litetree, error) { +func getDBConnection(dbName string) (*litetree, error) { if db, ok := database.DBs[dbName]; ok { return db, nil } @@ -304,7 +309,7 @@ type branchInfo struct { TotalCommits uint64 `json:"total_commits"` } -func (db *litetree) recoveryPoint() uint64 { +func (db *litetree) getLastRecoveryPoint() uint64 { row := db.QueryRowContext(context.Background(), "pragma branch_info(master)") var rv string err := row.Scan(&rv) @@ -320,7 +325,7 @@ func (db *litetree) recoveryPoint() uint64 { } func (db *litetree) restoreRecoveryPoint(stateRp uint64) error { - lastRp := db.recoveryPoint() + lastRp := db.getLastRecoveryPoint() if sqlLgr.IsDebugEnabled() { sqlLgr.Debug().Str("db_name", db.name). Uint64("state_rp", stateRp). diff --git a/contract/vm.go b/contract/vm.go index 11d5985a5..8f14a6369 100644 --- a/contract/vm.go +++ b/contract/vm.go @@ -5,41 +5,25 @@ package contract -/* - #cgo CFLAGS: -I${SRCDIR}/../libtool/include/luajit-2.1 -I${SRCDIR}/../libtool/include - #cgo !windows CFLAGS: -DLJ_TARGET_POSIX - #cgo darwin LDFLAGS: ${SRCDIR}/../libtool/lib/libluajit-5.1.a ${SRCDIR}/../libtool/lib/libgmp.dylib -lm - #cgo windows LDFLAGS: ${SRCDIR}/../libtool/lib/libluajit-5.1.a ${SRCDIR}/../libtool/bin/libgmp-10.dll -lm - #cgo !darwin,!windows LDFLAGS: ${SRCDIR}/../libtool/lib/libluajit-5.1.a -L${SRCDIR}/../libtool/lib64 -L${SRCDIR}/../libtool/lib -lgmp -lm - - - #include - #include - #include "vm.h" - #include "bignum_module.h" -*/ -import "C" import ( "bytes" "context" + "encoding/binary" "encoding/json" + "strconv" "errors" "fmt" "math/big" "math/rand" "os" - "reflect" - "sort" "strings" "sync" "time" - "unsafe" "github.com/aergoio/aergo-lib/log" luacUtil "github.com/aergoio/aergo/v2/cmd/aergoluac/util" + "github.com/aergoio/aergo/v2/contract/msg" "github.com/aergoio/aergo/v2/fee" - "github.com/aergoio/aergo/v2/internal/enc/base58" - "github.com/aergoio/aergo/v2/internal/enc/hex" "github.com/aergoio/aergo/v2/state" "github.com/aergoio/aergo/v2/state/statedb" "github.com/aergoio/aergo/v2/types" @@ -49,11 +33,11 @@ import ( ) const ( - callMaxInstLimit = C.int(5000000) - queryMaxInstLimit = callMaxInstLimit * C.int(10) + callMaxInstLimit = 5000000 + queryMaxInstLimit = callMaxInstLimit * 10 dbUpdateMaxLimit = fee.StateDbMaxUpdateSize maxCallDepthOld = 5 - maxCallDepth = 64 + maxCallDepth = 20 checkFeeDelegationFn = "check_delegation" constructor = "constructor" @@ -66,7 +50,7 @@ var ( contexts []*vmContext lastQueryIndex int querySync sync.Mutex - currentForkVersion int32 + CurrentForkVersion int32 ) type ChainAccessor interface { @@ -87,10 +71,10 @@ type vmContext struct { node string confirmed bool isQuery bool - nestedView int32 // indicates which parent called the contract in view (read-only mode) + nestedView int32 // indicates whether the parent called the contract in view (read-only) mode isFeeDelegation bool isMultiCall bool - service C.int + service int callState map[types.AccountID]*callState lastRecoveryEntry *recoveryEntry dbUpdateTotalSize int64 @@ -98,24 +82,27 @@ type vmContext struct { events []*types.Event eventCount int32 callDepth int32 + callStack []*executor traceFile *os.File gasLimit uint64 - remainedGas uint64 + remainingGas uint64 execCtx context.Context + deadline time.Time } type executor struct { - L *LState + vmInstance *VmInstance code []byte err error - numArgs C.int ci *types.CallInfo fname string ctx *vmContext + contractGasLimit uint64 + usedGas uint64 jsonRet string isView bool isAutoload bool - preErr error + abiErr error } func MaxCallDepth(version int32) int32 { @@ -125,6 +112,10 @@ func MaxCallDepth(version int32) int32 { return maxCallDepthOld } +func MaxPossibleCallDepth() int { + return maxCallDepth +} + func init() { ctrLgr = log.NewLogger("contract") lastQueryIndex = ChainService @@ -166,9 +157,9 @@ func NewVmContext( confirmed: confirmed, isQuery: query, blockInfo: bi, - service: C.int(executionMode), + service: executionMode, gasLimit: gasLimit, - remainedGas: gasLimit, + remainingGas: gasLimit, isFeeDelegation: feeDelegation, isMultiCall: isMultiCall, execCtx: execCtx, @@ -184,6 +175,11 @@ func NewVmContext( ctx.traceFile = getTraceFile(ctx.blockInfo.No, txHash) } + // use the deadline from the execution context + if deadline, ok := execCtx.Deadline(); ok { + ctx.deadline = deadline + } + return ctx } @@ -222,35 +218,45 @@ func (ctx *vmContext) IsMultiCall() bool { return ctx.isMultiCall } +//////////////////////////////////////////////////////////////////////////////// +// GAS +//////////////////////////////////////////////////////////////////////////////// + func (ctx *vmContext) IsGasSystem() bool { return fee.GasEnabled(ctx.blockInfo.ForkVersion) && !ctx.isQuery } -// get the remaining gas from the given LState -func (ctx *vmContext) refreshRemainingGas(L *LState) { - if ctx.IsGasSystem() { - ctx.remainedGas = uint64(C.lua_gasget(L)) +// check if the gas limit set by the parent VM instance is valid +func (ctx *vmContext) parseGasLimit(gas string) (uint64, error) { + // it must be a valid uint64 value + if len(gas) != 8 { + return 0, errors.New("uncatchable: invalid gas limit") + } + gasLimit := binary.LittleEndian.Uint64([]byte(gas)) + // gas limit must be less than or equal to the remaining gas + if gasLimit > ctx.remainingGas { + return 0, errors.New("uncatchable: gas limit exceeds the remaining gas") } + return gasLimit, nil } -// set the remaining gas on the given LState -func (ctx *vmContext) setRemainingGas(L *LState) { - if ctx.IsGasSystem() { - C.lua_gasset(L, C.ulonglong(ctx.remainedGas)) +// get the total gas used by all contracts in the current transaction +func (ctx *vmContext) usedGas() uint64 { + if fee.IsZeroFee() || !ctx.IsGasSystem() { + return 0 } + return ctx.gasLimit - ctx.remainingGas } +// get the contracts execution fee func (ctx *vmContext) usedFee() *big.Int { return fee.TxExecuteFee(ctx.blockInfo.ForkVersion, ctx.bs.GasPrice, ctx.usedGas(), ctx.dbUpdateTotalSize) } -func (ctx *vmContext) usedGas() uint64 { - if fee.IsZeroFee() || !ctx.IsGasSystem() { - return 0 - } - return ctx.gasLimit - ctx.remainedGas -} +//////////////////////////////////////////////////////////////////////////////// + +// TODO: is this used on private chains? if not, remove it func (ctx *vmContext) addUpdateSize(updateSize int64) error { if ctx.IsGasSystem() { return nil @@ -262,6 +268,8 @@ func (ctx *vmContext) addUpdateSize(updateSize int64) error { return nil } + + func resolveFunction(contractState *statedb.ContractState, bs *state.BlockState, name string, constructor bool) (*types.Function, error) { abi, err := GetABI(contractState, bs) if err != nil { @@ -285,81 +293,62 @@ func resolveFunction(contractState *statedb.ContractState, bs *state.BlockState, return nil, errors.New("not found function: " + name) } + + + func newExecutor( - contract []byte, + bytecode []byte, contractId []byte, ctx *vmContext, ci *types.CallInfo, amount *big.Int, isCreate bool, - isDelegation bool, + isFeeDelegation bool, ctrState *statedb.ContractState, ) *executor { - if ctx.blockInfo.ForkVersion != currentForkVersion { - // force the StatePool to regenerate the LStates + if ctx.blockInfo.ForkVersion != CurrentForkVersion { + // force the VM Pool to regenerate the VM instances // using the new hardfork version - currentForkVersion = ctx.blockInfo.ForkVersion - FlushLStates() + CurrentForkVersion = ctx.blockInfo.ForkVersion + FlushVmInstances() } + // create a new executor and add it to the call stack + ce := &executor{ + ctx: ctx, + code: bytecode, + } + ctx.callStack = append(ctx.callStack, ce) + ctx.callDepth++ if ctx.callDepth > MaxCallDepth(ctx.blockInfo.ForkVersion) { - ce := &executor{ - code: contract, - ctx: ctx, - } ce.err = fmt.Errorf("exceeded the maximum call depth(%d)", MaxCallDepth(ctx.blockInfo.ForkVersion)) return ce } - ctx.callDepth++ if blacklist.Check(types.EncodeAddress(contractId)) { - ce := &executor{ - code: contract, - ctx: ctx, - } ce.err = fmt.Errorf("contract not available") ctrLgr.Error().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("blocked contract") return ce } - ce := &executor{ - code: contract, - L: GetLState(), - ctx: ctx, - } - if ce.L == nil { + // get a connection to an unused VM instance + ce.vmInstance = GetVmInstance() + if ce.vmInstance == nil { ce.err = ErrVmStart ctrLgr.Error().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("new AergoLua executor") return ce } - if ctx.blockInfo.ForkVersion >= 2 { - C.luaL_set_hardforkversion(ce.L, C.int(ctx.blockInfo.ForkVersion)) - } - - if ctx.IsGasSystem() { - ce.setGas() - defer func() { - ce.refreshRemainingGas() - if ctrLgr.IsDebugEnabled() { - ctrLgr.Debug().Uint64("gas used", ce.ctx.usedGas()).Str("lua vm", "loaded").Msg("gas information") - } - }() - } - - ce.vmLoadCode(contractId) - if ce.err != nil { - return ce - } if isCreate { f, err := resolveFunction(ctrState, ctx.bs, constructor, isCreate) if err != nil { - ce.preErr = err + ce.abiErr = err ctrLgr.Debug().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("not found function") return ce } if f == nil { + // the constructor function does not need to be declared with abi.register() f = &types.Function{ Name: constructor, Payable: false, @@ -367,161 +356,44 @@ func newExecutor( } err = checkPayable(f, amount) if err != nil { - ce.preErr = err + ce.abiErr = err ctrLgr.Debug().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("check payable function") return ce } ce.isView = f.View ce.fname = constructor ce.isAutoload = true - ce.numArgs = C.int(len(ci.Args)) - } else if isDelegation { + } else if isFeeDelegation { _, err := resolveFunction(ctrState, ctx.bs, checkFeeDelegationFn, false) if err != nil { - ce.preErr = err + ce.abiErr = err ctrLgr.Debug().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("not found function") return ce } ce.isView = true ce.fname = checkFeeDelegationFn ce.isAutoload = true - ce.numArgs = C.int(len(ci.Args)) } else { - f, err := resolveFunction(ctrState, ctx.bs, ci.Name, isCreate) + f, err := resolveFunction(ctrState, ctx.bs, ci.Name, false) if err != nil { - ce.preErr = err + ce.abiErr = err ctrLgr.Debug().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("not found function") return ce } err = checkPayable(f, amount) if err != nil { - ce.preErr = err + ce.abiErr = err ctrLgr.Debug().Err(ce.err).Str("contract", types.EncodeAddress(contractId)).Msg("check payable function") return ce } ce.isView = f.View ce.fname = f.Name - ce.numArgs = C.int(len(ci.Args) + 1) } ce.ci = ci return ce } -func (ce *executor) processArgs() { - for _, v := range ce.ci.Args { - if err := pushValue(ce.L, v); err != nil { - ce.err = err - return - } - } -} - -func (ce *executor) getEvents() []*types.Event { - if ce == nil || ce.ctx == nil { - return nil - } - return ce.ctx.events -} - -func pushValue(L *LState, v interface{}) error { - switch arg := v.(type) { - case string: - argC := C.CBytes([]byte(arg)) - C.lua_pushlstring(L, (*C.char)(argC), C.size_t(len(arg))) - C.free(argC) - case float64: - if arg == float64(int64(arg)) { - C.lua_pushinteger(L, C.lua_Integer(arg)) - } else { - C.lua_pushnumber(L, C.double(arg)) - } - case bool: - var b int - if arg { - b = 1 - } - C.lua_pushboolean(L, C.int(b)) - case json.Number: - str := arg.String() - intVal, err := arg.Int64() - if err == nil { - C.lua_pushinteger(L, C.lua_Integer(intVal)) - } else { - ftVal, err := arg.Float64() - if err != nil { - return errors.New("unsupported number type:" + str) - } - C.lua_pushnumber(L, C.double(ftVal)) - } - case nil: - C.lua_pushnil(L) - case []interface{}: - err := toLuaArray(L, arg) - if err != nil { - return err - } - case map[string]interface{}: - err := toLuaTable(L, arg) - if err != nil { - return err - } - default: - return errors.New("unsupported type:" + reflect.TypeOf(v).Name()) - } - return nil -} - -func toLuaArray(L *LState, arr []interface{}) error { - C.lua_createtable(L, C.int(len(arr)), C.int(0)) - n := C.lua_gettop(L) - for i, v := range arr { - if err := pushValue(L, v); err != nil { - return err - } - C.lua_rawseti(L, n, C.int(i+1)) - } - return nil -} - -func toLuaTable(L *LState, tab map[string]interface{}) error { - C.lua_createtable(L, C.int(0), C.int(len(tab))) - n := C.lua_gettop(L) - // get the keys and sort them - keys := make([]string, 0, len(tab)) - for k := range tab { - keys = append(keys, k) - } - if C.vm_is_hardfork(L, 3) { - sort.Strings(keys) - } - for _, k := range keys { - v := tab[k] - if len(tab) == 1 && strings.EqualFold(k, "_bignum") { - if arg, ok := v.(string); ok { - C.lua_settop(L, -2) - argC := C.CString(arg) - msg := C.lua_set_bignum(L, argC) - C.free(unsafe.Pointer(argC)) - if msg != nil { - return errors.New(C.GoString(msg)) - } - return nil - } - } - // push a key - key := C.CString(k) - C.lua_pushstring(L, key) - C.free(unsafe.Pointer(key)) - - if err := pushValue(L, v); err != nil { - return err - } - C.lua_rawset(L, n) - } - return nil -} - func checkPayable(callee *types.Function, amount *big.Int) error { if amount.Cmp(big.NewInt(0)) <= 0 || callee.Payable { return nil @@ -529,115 +401,11 @@ func checkPayable(callee *types.Function, amount *big.Int) error { return fmt.Errorf("'%s' is not payable", callee.Name) } -func (ce *executor) call(instLimit C.int, target *LState) (ret C.int) { - defer func() { - if ret == 0 && target != nil { - if C.luaL_hasuncatchablerror(ce.L) != C.int(0) { - C.luaL_setuncatchablerror(target) - } - if C.luaL_hassyserror(ce.L) != C.int(0) { - C.luaL_setsyserror(target) - } - } - }() - if ce.err != nil { - return 0 - } - defer ce.refreshRemainingGas() - if ce.isView == true { - ce.ctx.nestedView++ - defer func() { - ce.ctx.nestedView-- - }() - } - ce.vmLoadCall() - if ce.err != nil { - return 0 - } - if ce.preErr != nil { - ce.err = ce.preErr - return 0 - } - if ce.isAutoload { - if loaded := vmAutoload(ce.L, ce.fname); !loaded { - if ce.fname != constructor { - ce.err = errors.New(fmt.Sprintf("contract autoload failed %s : %s", - types.EncodeAddress(ce.ctx.curContract.contractId), ce.fname)) - } - return 0 - } - } else { - C.vm_remove_constructor(ce.L) - resolvedName := C.CString(ce.fname) - C.vm_get_abi_function(ce.L, resolvedName) - C.free(unsafe.Pointer(resolvedName)) - } - ce.processArgs() - if ce.err != nil { - ctrLgr.Debug().Err(ce.err).Stringer("contract", - types.LogAddr(ce.ctx.curContract.contractId)).Msg("invalid argument") - return 0 - } - ce.setCountHook(instLimit) - nRet := C.int(0) - cErrMsg := C.vm_pcall(ce.L, ce.numArgs, &nRet) - if cErrMsg != nil { - errMsg := C.GoString(cErrMsg) - if C.luaL_hassyserror(ce.L) != C.int(0) { - ce.err = newVmSystemError(errors.New(errMsg)) - } else { - isUncatchable := C.luaL_hasuncatchablerror(ce.L) != C.int(0) - if isUncatchable && (errMsg == C.ERR_BF_TIMEOUT || errMsg == vmTimeoutErrMsg) { - ce.err = &VmTimeoutError{} - } else { - ce.err = errors.New(errMsg) - } - } - ctrLgr.Debug().Err(ce.err).Stringer( - "contract", - types.LogAddr(ce.ctx.curContract.contractId), - ).Msg("contract is failed") - return 0 - } - if target == nil { - var errRet C.int - retMsg := C.GoString(C.vm_get_json_ret(ce.L, nRet, &errRet)) - if errRet == 1 { - ce.err = errors.New(retMsg) - } else { - ce.jsonRet = retMsg - } - } else { - if c2ErrMsg := C.vm_copy_result(ce.L, target, nRet); c2ErrMsg != nil { - errMsg := C.GoString(c2ErrMsg) - ce.err = errors.New(errMsg) - ctrLgr.Debug().Err(ce.err).Stringer( - "contract", - types.LogAddr(ce.ctx.curContract.contractId), - ).Msg("failed to move results") - } - } - if ce.ctx.traceFile != nil { - address := types.EncodeAddress(ce.ctx.curContract.contractId) - codeFile := fmt.Sprintf("%s%s%s.code", os.TempDir(), string(os.PathSeparator), address) - if _, err := os.Stat(codeFile); os.IsNotExist(err) { - f, err := os.OpenFile(codeFile, os.O_WRONLY|os.O_CREATE, 0644) - if err == nil { - _, _ = f.Write(ce.code) - _ = f.Close() - } - } - _, _ = ce.ctx.traceFile.WriteString(fmt.Sprintf("contract %s used fee: %s\n", - address, ce.ctx.usedFee().String())) +func (ce *executor) getEvents() []*types.Event { + if ce == nil || ce.ctx == nil { + return nil } - return nRet -} - -func vmAutoload(L *LState, funcName string) bool { - s := C.CString(funcName) - loaded := C.vm_autoload(L, s) - C.free(unsafe.Pointer(s)) - return loaded != C.int(0) + return ce.ctx.events } func (ce *executor) commitCalledContract() error { @@ -736,69 +504,22 @@ func (ce *executor) closeQuerySql() error { return nil } -func (ce *executor) setGas() { - if ce == nil || ce.L == nil || ce.err != nil { - return - } - C.lua_gasset(ce.L, C.ulonglong(ce.ctx.remainedGas)) -} - func (ce *executor) close() { if ce != nil { + FreeVmInstance(ce.vmInstance) if ce.ctx != nil { ce.ctx.callDepth-- + ce.ctx.callStack = ce.ctx.callStack[:len(ce.ctx.callStack)-1] if ce.ctx.traceFile != nil { ce.ctx.traceFile.Close() ce.ctx.traceFile = nil } } - if ce.L != nil { - FreeLState(ce.L) - } } } -func (ce *executor) refreshRemainingGas() { - ce.ctx.refreshRemainingGas(ce.L) -} - -func (ce *executor) gas() uint64 { - return uint64(C.lua_gasget(ce.L)) -} -func (ce *executor) vmLoadCode(id []byte) { - var chunkId *C.char - if ce.ctx.blockInfo.ForkVersion >= 3 { - chunkId = C.CString("@" + types.EncodeAddress(id)) - } else { - chunkId = C.CString(hex.Encode(id)) - } - defer C.free(unsafe.Pointer(chunkId)) - if cErrMsg := C.vm_loadbuff( - ce.L, - (*C.char)(unsafe.Pointer(&ce.code[0])), - C.size_t(len(ce.code)), - chunkId, - ce.ctx.service-C.int(maxContext), - ); cErrMsg != nil { - errMsg := C.GoString(cErrMsg) - ce.err = errors.New(errMsg) - ctrLgr.Debug().Err(ce.err).Str("contract", types.EncodeAddress(id)).Msg("failed to load code") - } -} -func (ce *executor) vmLoadCall() { - if cErrMsg := C.vm_loadcall(ce.L); cErrMsg != nil { - errMsg := C.GoString(cErrMsg) - isUncatchable := C.luaL_hasuncatchablerror(ce.L) != C.int(0) - if isUncatchable && (errMsg == C.ERR_BF_TIMEOUT || errMsg == vmTimeoutErrMsg) { - ce.err = &VmTimeoutError{} - } else { - ce.err = errors.New(errMsg) - } - } - C.luaL_set_service(ce.L, ce.ctx.service) -} func getMultiCallInfo(ci *types.CallInfo, payload []byte) error { payload = append([]byte{'['}, payload...) @@ -807,20 +528,35 @@ func getMultiCallInfo(ci *types.CallInfo, payload []byte) error { return getCallInfo(&ci.Args, payload, []byte("multicall")) } +// ci is a pointer to a CallInfo struct: { Name string, Args []interface{} } +// args is a JSON array of arguments func getCallInfo(ci interface{}, args []byte, contractAddress []byte) error { d := json.NewDecoder(bytes.NewReader(args)) d.UseNumber() d.DisallowUnknownFields() err := d.Decode(ci) if err != nil { - ctrLgr.Debug().AnErr("error", err).Str( - "contract", - types.EncodeAddress(contractAddress), - ).Msg("invalid calling information") + ctrLgr.Debug().AnErr("error", err).Str("contract", types.EncodeAddress(contractAddress)).Msg("invalid calling information") } return err } +// return only the arguments as a single string containing the JSON array +func getCallInfoArgs(ci *types.CallInfo) (string, error) { + args, err := json.Marshal(ci.Args) + if err != nil { + return "", err + } + return string(args), nil +} + + + + +//////////////////////////////////////////////////////////////////////////////// +// Called Externally +//////////////////////////////////////////////////////////////////////////////// + func Call( contractState *statedb.ContractState, payload, contractAddress []byte, @@ -860,14 +596,19 @@ func Call( // create a new executor contexts[ctx.service] = ctx ce := newExecutor(bytecode, contractAddress, ctx, &ci, ctx.curContract.amount, false, false, contractState) - defer ce.close() + defer ce.close() // close the executor and the VM instance + + // set the gas limit from the transaction + ce.contractGasLimit = ctx.gasLimit if ce.err == nil { - startTime := time.Now() // execute the contract call - ce.call(callMaxInstLimit, nil) - vmExecTime := time.Now().Sub(startTime).Microseconds() - vmLogger.Trace().Int64("execµs", vmExecTime).Stringer("txHash", types.LogBase58(ce.ctx.txHash)).Msg("tx execute time in vm") + ce.call(true) + // update the total used gas + err = ctx.updateUsedGas(ce.usedGas) + if err != nil { + return "", nil, ctx.usedFee(), err + } } // check if there is an error @@ -925,71 +666,6 @@ func Call( return ce.jsonRet, ce.getEvents(), ctx.usedFee(), nil } -func setRandomSeed(ctx *vmContext) { - var randSrc rand.Source - if ctx.isQuery { - randSrc = rand.NewSource(ctx.blockInfo.Ts) - } else { - b, _ := new(big.Int).SetString(base58.Encode(ctx.blockInfo.PrevBlockHash[:7]), 62) - t, _ := new(big.Int).SetString(base58.Encode(ctx.txHash[:7]), 62) - b.Add(b, t) - randSrc = rand.NewSource(b.Int64()) - } - ctx.seed = rand.New(randSrc) -} - -func setContract(contractState *statedb.ContractState, contractAddress, payload []byte, ctx *vmContext) ([]byte, []byte, error) { - // the payload contains: - // on V3: bytecode + ABI + constructor arguments - // on V4: lua code + constructor arguments - codePayload := luacUtil.LuaCodePayload(payload) - if _, err := codePayload.IsValidFormat(); err != nil { - ctrLgr.Warn().Err(err).Str("contract", types.EncodeAddress(contractAddress)).Msg("deploy") - return nil, nil, err - } - code := codePayload.Code() // type: LuaCode - - var sourceCode []byte - var bytecodeABI []byte - var err error - - // if hardfork version 4 - if ctx.blockInfo.ForkVersion >= 4 { - // the payload must be lua code. compile it to bytecode - sourceCode = code - bytecodeABI, err = Compile(string(sourceCode), nil) - if err != nil { - ctrLgr.Warn().Err(err).Str("contract", types.EncodeAddress(contractAddress)).Msg("deploy") - return nil, nil, err - } - } else { - // on previous hardfork versions the payload is bytecode - bytecodeABI = code - } - - // save the bytecode to the contract state - err = contractState.SetCode(sourceCode, bytecodeABI) - if err != nil { - return nil, nil, err - } - - // extract the bytecode - bytecode := luacUtil.LuaCode(bytecodeABI).ByteCode() - - // check if it was properly stored - savedBytecode := getContractCode(contractState, nil) - if savedBytecode == nil || !bytes.Equal(savedBytecode, bytecode) { - err = fmt.Errorf("cannot deploy contract %s", types.EncodeAddress(contractAddress)) - ctrLgr.Warn().Str("error", "cannot load contract").Str( - "contract", - types.EncodeAddress(contractAddress), - ).Msg("deploy") - return nil, nil, err - } - - return bytecode, codePayload.Args(), nil -} - func Create( contractState *statedb.ContractState, payload, contractAddress []byte, @@ -1028,20 +704,21 @@ func Create( contexts[ctx.service] = ctx - if ctx.blockInfo.ForkVersion < 2 { - // create a sql database for the contract - if db := luaGetDbHandle(ctx.service); db == nil { - return "", nil, ctx.usedFee(), newVmError(errors.New("can't open a database connection")) - } - } - // create a new executor for the constructor ce := newExecutor(bytecode, contractAddress, ctx, &ci, ctx.curContract.amount, true, false, contractState) - defer ce.close() + defer ce.close() // close the executor and the VM instance - if err == nil { + // set the gas limit from the transaction + ce.contractGasLimit = ctx.gasLimit + + if ce.err == nil { // call the constructor - ce.call(callMaxInstLimit, nil) + ce.call(true) + // update the total used gas + err = ctx.updateUsedGas(ce.usedGas) + if err != nil { + return "", nil, ctx.usedFee(), err + } } // check if the call failed @@ -1112,7 +789,7 @@ func allocContextSlot(ctx *vmContext) { index = ChainService + 1 } if contexts[index] == nil { - ctx.service = C.int(index) + ctx.service = index contexts[index] = ctx lastQueryIndex = index return @@ -1162,20 +839,34 @@ func Query(contractAddress []byte, bs *state.BlockState, cdb ChainAccessor, cont } ce := newExecutor(bytecode, contractAddress, ctx, &ci, ctx.curContract.amount, false, false, contractState) - defer ce.close() + defer ce.close() // close the executor and the VM instance defer func() { if dbErr := ce.closeQuerySql(); dbErr != nil { err = dbErr } }() - if err == nil { - ce.call(queryMaxInstLimit, nil) + // set the gas limit from the transaction + ce.contractGasLimit = ctx.gasLimit + + if ce.err == nil { + // execute the contract call + ce.call(true) + // update the total used gas + err = ctx.updateUsedGas(ce.usedGas) + if err != nil { + return nil, err + } } return []byte(ce.jsonRet), ce.err } + + +//! this is complicated, a query before the actual execution +// and queried many times, even by mempool + func CheckFeeDelegation(contractAddress []byte, bs *state.BlockState, bi *types.BlockHeaderInfo, cdb ChainAccessor, contractState *statedb.ContractState, payload, txHash, sender, amount []byte) (err error) { var ci types.CallInfo @@ -1238,15 +929,24 @@ func CheckFeeDelegation(contractAddress []byte, bs *state.BlockState, bi *types. ci.Name = checkFeeDelegationFn ce := newExecutor(bytecode, contractAddress, ctx, &ci, ctx.curContract.amount, false, true, contractState) - defer ce.close() + defer ce.close() // close the executor and the VM instance defer func() { if dbErr := ce.rollbackToSavepoint(); dbErr != nil { err = dbErr } }() - if err == nil { - ce.call(queryMaxInstLimit, nil) + // set the gas limit from the transaction + ce.contractGasLimit = ctx.gasLimit + + if ce.err == nil { + // execute the contract call + ce.call(true) + // update the total used gas + err = ctx.updateUsedGas(ce.usedGas) + if err != nil { + return err + } } if ce.err != nil { @@ -1258,6 +958,72 @@ func CheckFeeDelegation(contractAddress []byte, bs *state.BlockState, bi *types. return nil } +func (ctx *vmContext) updateUsedGas(usedGas uint64) error { + if usedGas > ctx.remainingGas { + ctx.remainingGas = 0 + return errors.New("run out of gas") + } + // deduct the used gas + ctx.remainingGas -= usedGas + return nil +} + + + +//////////////////////////////////////////////////////////////////////////////// +// Contract Code +//////////////////////////////////////////////////////////////////////////////// + +// only called by a deploy transaction +func setContract(contractState *statedb.ContractState, contractAddress, payload []byte, ctx *vmContext) ([]byte, []byte, error) { + // the payload contains: + // on V3: bytecode + ABI + constructor arguments + // on V4: lua code + constructor arguments + codePayload := luacUtil.LuaCodePayload(payload) + if _, err := codePayload.IsValidFormat(); err != nil { + ctrLgr.Warn().Err(err).Str("contract", types.EncodeAddress(contractAddress)).Msg("deploy") + return nil, nil, err + } + code := codePayload.Code() // type: LuaCode + + var sourceCode []byte + var bytecodeABI []byte + var err error + + // if hardfork version 4 + if ctx.blockInfo.ForkVersion >= 4 { + // the payload must be lua code. compile it to bytecode + sourceCode = code + bytecodeABI, err = Compile(string(sourceCode), false) + if err != nil { + ctrLgr.Warn().Err(err).Str("contract", types.EncodeAddress(contractAddress)).Msg("deploy") + return nil, nil, err + } + } else { + // on previous hardfork versions the payload is bytecode + bytecodeABI = code + } + + // save the bytecode to the contract state + err = contractState.SetCode(sourceCode, bytecodeABI) + if err != nil { + return nil, nil, err + } + + // extract the bytecode + bytecode := luacUtil.LuaCode(bytecodeABI).ByteCode() + + // check if it was properly stored + savedBytecode := getContractCode(contractState, nil) + if savedBytecode == nil || !bytes.Equal(savedBytecode, bytecode) { + err = fmt.Errorf("cannot deploy contract %s", types.EncodeAddress(contractAddress)) + ctrLgr.Warn().Str("error", "cannot load contract").Str("contract", types.EncodeAddress(contractAddress)).Msg("deploy") + return nil, nil, err + } + + return bytecode, codePayload.Args(), nil +} + func getCode(contractState *statedb.ContractState, bs *state.BlockState) ([]byte, error) { var code []byte var err error @@ -1308,7 +1074,7 @@ func getMultiCallCode(contractState *statedb.ContractState) []byte { if multicall_compiled == nil { // compile the Lua code used to execute multicall txns var err error - multicall_compiled, err = Compile(multicall_code, nil) + multicall_compiled, err = Compile(multicall_code, false) if err != nil { ctrLgr.Error().Err(err).Msg("multicall compile") return nil @@ -1357,30 +1123,57 @@ func GetABI(contractState *statedb.ContractState, bs *state.BlockState) (*types. return abi, nil } -func Compile(code string, parent *LState) (luacUtil.LuaCode, error) { - L := luacUtil.NewLState() - if L == nil { - return nil, ErrVmStart - } - defer luacUtil.CloseLState(L) - if parent != nil { - var lState = (*LState)(L) - if cErrMsg := C.vm_copy_service(lState, parent); cErrMsg != nil { - if C.luaL_hasuncatchablerror(lState) != C.int(0) { - C.luaL_setuncatchablerror(parent) - } - errMsg := C.GoString(cErrMsg) - return nil, errors.New(errMsg) - } - C.luaL_set_hardforkversion(lState, C.luaL_hardforkversion(parent)) - C.vm_set_timeout_hook(lState) +// send the source code to a VM instance, to be compiled +func Compile(code string, hasParent bool) (luacUtil.LuaCode, error) { + + // get a connection to an unused VM instance + vmInstance := GetVmInstance() + if vmInstance == nil { + err := ErrVmStart + ctrLgr.Error().Err(err).Msg("get vm instance for compilation") + return nil, err } - byteCodeAbi, err := luacUtil.Compile(L, code) + defer FreeVmInstance(vmInstance) + + // build the message + message := msg.SerializeMessage("compile", code, strconv.FormatBool(hasParent)) + + /*/ encrypt the message + message, err = msg.Encrypt(message, secretKey) + if err != nil { + return nil, err + } + */ + + // send the execution request to the VM instance + err := msg.SendMessage(vmInstance.conn, message) + if err != nil { + return nil, fmt.Errorf("compile: send message: %v", err) + } + + // timeout of 250 ms + deadline := time.Now().Add(250 * time.Millisecond) + response, err := msg.WaitForMessage(vmInstance.conn, deadline) + if err != nil { + return nil, fmt.Errorf("compile: wait for message: %v", err) + } + + /*/ decrypt the message + response, err = msg.Decrypt(response, secretKey) if err != nil { - if parent != nil && C.luaL_hasuncatchablerror((*LState)(L)) != C.int(0) { - C.luaL_setuncatchablerror(parent) - } return nil, err } - return byteCodeAbi, nil + */ + + results, err := msg.DeserializeMessage(response) + if len(results) != 2 { + return nil, fmt.Errorf("compile: invalid number of results: %v", results) + } + bytecodeAbi := results[0] + errMsg := results[1] + + if len(errMsg) > 0 { + return nil, fmt.Errorf("compile: %s", errMsg) + } + return luacUtil.LuaCode(bytecodeAbi), nil } diff --git a/contract/vm.h b/contract/vm.h deleted file mode 100644 index e32341a68..000000000 --- a/contract/vm.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _VM_H -#define _VM_H - -#include -#include -#include -#include -#include -#include "sqlite3-binding.h" - -extern const char *construct_name; - -#define FORK_V2 "_FORK_V2" -#define ERR_BF_TIMEOUT "contract timeout" - -lua_State *vm_newstate(int hardfork_version); -void vm_closestates(lua_State *s[], int count); -int vm_autoload(lua_State *L, char *func_name); -void vm_remove_constructor(lua_State *L); -const char *vm_loadbuff(lua_State *L, const char *code, size_t sz, char *hex_id, int service); -const char *vm_pcall(lua_State *L, int argc, int *nresult); -const char *vm_get_json_ret(lua_State *L, int nresult, int *err); -const char *vm_copy_result(lua_State *L, lua_State *target, int cnt); -sqlite3 *vm_get_db(lua_State *L); -void vm_get_abi_function(lua_State *L, char *fname); -void vm_set_count_hook(lua_State *L, int limit); -void vm_db_release_resource(lua_State *L); -bool vm_is_hardfork(lua_State *L, int version); -void initViewFunction(); -void vm_set_timeout_hook(lua_State *L); -void vm_set_timeout_count_hook(lua_State *L, int limit); -int vm_instcount(lua_State *L); -void vm_setinstcount(lua_State *L, int count); -const char *vm_copy_service(lua_State *L, lua_State *main); -const char *vm_loadcall(lua_State *L); - -#endif /* _VM_H */ diff --git a/contract/bignum_module.c b/contract/vm/bignum_module.c similarity index 100% rename from contract/bignum_module.c rename to contract/vm/bignum_module.c diff --git a/contract/bignum_module.h b/contract/vm/bignum_module.h similarity index 100% rename from contract/bignum_module.h rename to contract/vm/bignum_module.h diff --git a/contract/contract_module.c b/contract/vm/contract_module.c similarity index 81% rename from contract/contract_module.c rename to contract/vm/contract_module.c index c022a0a9c..85e407c44 100644 --- a/contract/contract_module.c +++ b/contract/vm/contract_module.c @@ -5,7 +5,7 @@ #include "bignum_module.h" #include "_cgo_export.h" -extern int getLuaExecContext(lua_State *L); +extern void checkLuaExecContext(lua_State *L); static const char *contract_str = "contract"; static const char *call_str = "call"; @@ -90,10 +90,11 @@ static int moduleCall(lua_State *L) { char *fname; char *json_args; struct luaCallContract_return ret; - int service = getLuaExecContext(L); lua_Integer gas; char *amount; + checkLuaExecContext(L); + if (lua_gettop(L) == 2) { lua_gasuse(L, 300); } else { @@ -117,7 +118,8 @@ static int moduleCall(lua_State *L) { lua_pop(L, 2); contract = (char *)luaL_checkstring(L, 2); if (lua_gettop(L) == 2) { - char *errStr = luaSendAmount(L, service, contract, amount); + // when called with contract.call.value(amount)(address) - this triggers a call to the default() function + char *errStr = luaSendAmount(L, contract, amount); reset_amount_info(L); if (errStr != NULL) { strPushAndRelease(L, errStr); @@ -132,16 +134,25 @@ static int moduleCall(lua_State *L) { luaL_throwerror(L); } - ret = luaCallContract(L, service, contract, fname, json_args, amount, gas); + ret = luaCallContract(L, contract, fname, json_args, amount, gas); + free(json_args); + reset_amount_info(L); + + // if it returned an error message, push it to the stack and throw an error if (ret.r1 != NULL) { - free(json_args); - reset_amount_info(L); strPushAndRelease(L, ret.r1); luaL_throwerror(L); } - free(json_args); - reset_amount_info(L); - return ret.r0; + // push the returned values to the stack + int count = lua_util_json_array_to_lua(L, ret.r0, true); + free(ret.r0); + if (count == -1) { + luaL_setuncatchablerror(L); + lua_pushstring(L, "internal error: result from call is not a valid JSON array"); + luaL_throwerror(L); + } + // return the number of items in the stack + return count; } static int delegate_call_gas(lua_State *L) { @@ -153,9 +164,10 @@ static int moduleDelegateCall(lua_State *L) { char *fname; char *json_args; struct luaDelegateCallContract_return ret; - int service = getLuaExecContext(L); lua_Integer gas; + checkLuaExecContext(L); + lua_gasuse(L, 2000); lua_getfield(L, 1, fee_str); @@ -173,26 +185,36 @@ static int moduleDelegateCall(lua_State *L) { reset_amount_info(L); luaL_throwerror(L); } - ret = luaDelegateCallContract(L, service, contract, fname, json_args, gas); + + ret = luaDelegateCallContract(L, contract, fname, json_args, gas); + free(json_args); + reset_amount_info(L); + + // if it returned an error message, push it to the stack and throw an error if (ret.r1 != NULL) { - free(json_args); - reset_amount_info(L); strPushAndRelease(L, ret.r1); luaL_throwerror(L); } - free(json_args); - reset_amount_info(L); - - return ret.r0; + // push the returned values to the stack + int count = lua_util_json_array_to_lua(L, ret.r0, true); + free(ret.r0); + if (count == -1) { + luaL_setuncatchablerror(L); + lua_pushstring(L, "internal error: result from call is not a valid JSON array"); + luaL_throwerror(L); + } + // return the number of items in the stack + return count; } static int moduleSend(lua_State *L) { char *contract; char *errStr; - int service = getLuaExecContext(L); char *amount; bool needfree = false; + checkLuaExecContext(L); + lua_gasuse(L, 300); contract = (char *) luaL_checkstring(L, 1); @@ -218,7 +240,7 @@ static int moduleSend(lua_State *L) { luaL_error(L, "invalid input"); } - errStr = luaSendAmount(L, service, contract, amount); + errStr = luaSendAmount(L, contract, amount); if (needfree) { free(amount); @@ -232,10 +254,11 @@ static int moduleSend(lua_State *L) { static int moduleBalance(lua_State *L) { char *contract; - int service = getLuaExecContext(L); struct luaGetBalance_return balance; int nArg; + checkLuaExecContext(L); + lua_gasuse(L, 300); nArg = lua_gettop(L); @@ -255,7 +278,7 @@ static int moduleBalance(lua_State *L) { if (mode != -1) { struct luaGetStaking_return ret; const char *errMsg; - ret = luaGetStaking(service, contract); + ret = luaGetStaking(L, contract); if (ret.r2 != NULL) { strPushAndRelease(L, ret.r2); luaL_throwerror(L); @@ -273,7 +296,7 @@ static int moduleBalance(lua_State *L) { } } - balance = luaGetBalance(L, service, contract); + balance = luaGetBalance(L, contract); if (balance.r1 != NULL) { strPushAndRelease(L, balance.r1); luaL_throwerror(L); @@ -284,40 +307,42 @@ static int moduleBalance(lua_State *L) { } static int modulePcall(lua_State *L) { - int argc = lua_gettop(L) - 1; - int service = getLuaExecContext(L); - int num_events = luaGetEventCount(L, service); struct luaSetRecoveryPoint_return start_seq; + int argc; int ret; + checkLuaExecContext(L); + + argc = lua_gettop(L) - 1; + lua_gasuse(L, 300); - start_seq = luaSetRecoveryPoint(L, service); + // create a recovery point + start_seq = luaSetRecoveryPoint(L); if (start_seq.r0 < 0) { strPushAndRelease(L, start_seq.r1); luaL_throwerror(L); } - if ((ret = lua_pcall(L, argc, LUA_MULTRET, 0)) != 0) { - // if out of memory, throw error - if (ret == LUA_ERRMEM) { - luaL_throwerror(L); - } - // add 'success = false' as the first returned value - lua_pushboolean(L, false); - lua_insert(L, 1); - // drop the events - if (vm_is_hardfork(L, 4)) { - luaDropEvent(L, service, num_events); - } + // call the function + ret = lua_pcall(L, argc, LUA_MULTRET, 0); + if (ret != 0) { // revert the contract state if (start_seq.r0 > 0) { - char *errStr = luaClearRecovery(L, service, start_seq.r0, true); + char *errStr = luaClearRecovery(L, start_seq.r0, true); if (errStr != NULL) { strPushAndRelease(L, errStr); luaL_throwerror(L); } } + // if out of memory, throw error + if (ret == LUA_ERRMEM) { + luaL_throwerror(L); + } + // add 'success = false' as the first returned value + lua_pushboolean(L, false); + lua_insert(L, 1); + // return the 2 values return 2; } @@ -327,11 +352,8 @@ static int modulePcall(lua_State *L) { // release the recovery point if (start_seq.r0 == 1) { - char *errStr = luaClearRecovery(L, service, start_seq.r0, false); + char *errStr = luaClearRecovery(L, start_seq.r0, false); if (errStr != NULL) { - if (vm_is_hardfork(L, 4)) { - luaDropEvent(L, service, num_events); - } strPushAndRelease(L, errStr); luaL_throwerror(L); } @@ -350,9 +372,10 @@ static int moduleDeploy(lua_State *L) { char *fname; char *json_args; struct luaDeployContract_return ret; - int service = getLuaExecContext(L); char *amount; + checkLuaExecContext(L); + lua_gasuse(L, 5000); // get the amount to transfer to the new contract @@ -372,29 +395,34 @@ static int moduleDeploy(lua_State *L) { luaL_throwerror(L); } - ret = luaDeployContract(L, service, contract, json_args, amount); - if (ret.r0 < 0) { - free(json_args); - reset_amount_info(L); + ret = luaDeployContract(L, contract, json_args, amount); + free(json_args); + reset_amount_info(L); + + // if it returned an error message, push it to the stack and throw an error + if (ret.r1 != NULL) { strPushAndRelease(L, ret.r1); luaL_throwerror(L); } - - free(json_args); - reset_amount_info(L); - strPushAndRelease(L, ret.r1); - if (ret.r0 > 1) { - lua_insert(L, -ret.r0); + // push the returned values to the stack + int count = lua_util_json_array_to_lua(L, ret.r0, true); + free(ret.r0); + if (count == -1) { + luaL_setuncatchablerror(L); + lua_pushstring(L, "internal error: result from call is not a valid JSON array"); + luaL_throwerror(L); } - return ret.r0; + // return the number of items in the stack + return count; } static int moduleEvent(lua_State *L) { char *event_name; char *json_args; - int service = getLuaExecContext(L); char *errStr; + checkLuaExecContext(L); + lua_gasuse(L, 500); event_name = (char *) luaL_checkstring(L, 1); @@ -406,21 +434,23 @@ static int moduleEvent(lua_State *L) { if (json_args == NULL) { luaL_throwerror(L); } - errStr = luaEvent(L, service, event_name, json_args); + + errStr = luaEvent(L, event_name, json_args); + free(json_args); if (errStr != NULL) { strPushAndRelease(L, errStr); luaL_throwerror(L); } - free(json_args); return 0; } static int governance(lua_State *L, char type) { char *ret; - int service = getLuaExecContext(L); char *arg; bool needfree = false; + checkLuaExecContext(L); + lua_gasuse(L, 500); if (type == 'S' || type == 'U') { @@ -455,7 +485,8 @@ static int governance(lua_State *L, char type) { } needfree = true; } - ret = luaGovernance(L, service, type, arg); + + ret = luaGovernance(L, type, arg); if (needfree) { free(arg); } diff --git a/contract/contract_module.h b/contract/vm/contract_module.h similarity index 100% rename from contract/contract_module.h rename to contract/vm/contract_module.h diff --git a/contract/crypto_module.c b/contract/vm/crypto_module.c similarity index 88% rename from contract/crypto_module.c rename to contract/vm/crypto_module.c index 2fa1e68cc..a60727c59 100644 --- a/contract/crypto_module.c +++ b/contract/vm/crypto_module.c @@ -1,7 +1,7 @@ #include "_cgo_export.h" #include "util.h" -extern int getLuaExecContext(lua_State *L); +extern void checkLuaExecContext(lua_State *L); static int crypto_sha256(lua_State *L) { size_t len; @@ -25,7 +25,8 @@ static int crypto_sha256(lua_State *L) { static int crypto_ecverify(lua_State *L) { char *msg, *sig, *addr; struct luaECVerify_return ret; - int service = getLuaExecContext(L); + + checkLuaExecContext(L); lua_gasuse(L, 5000); @@ -36,14 +37,12 @@ static int crypto_ecverify(lua_State *L) { sig = (char *) lua_tostring(L, 2); addr = (char *) lua_tostring(L, 3); - ret = luaECVerify(L, service, msg, sig, addr); + ret = luaECVerify(L, msg, sig, addr); if (ret.r1 != NULL) { strPushAndRelease(L, ret.r1); lua_error(L); } - lua_pushboolean(L, ret.r0); - return 1; } @@ -96,12 +95,13 @@ static struct rlp_obj *makeValue(lua_State *L, int n) { } static int crypto_verifyProof(lua_State *L) { + struct luaCryptoVerifyProof_return ret; int argc = lua_gettop(L); char *k, *h; struct rlp_obj *v; struct proof *proof; size_t kLen, hLen, nProof; - int i, b; + int i; const int proofIndex = 4; lua_gasuse(L, 5000); @@ -121,11 +121,15 @@ static int crypto_verifyProof(lua_State *L) { proof[i-proofIndex].data = (char *) lua_tolstring(L, i, &proof[i-proofIndex].len); } - b = luaCryptoVerifyProof(k, kLen, v, h, hLen, proof, nProof); + ret = luaCryptoVerifyProof(L, k, kLen, v, h, hLen, proof, nProof); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + lua_error(L); + } + if (proof != NULL) { free(proof); } - if (v != NULL) { if (v->rlp_obj_type == RLP_TLIST) { free(v->data); @@ -133,7 +137,7 @@ static int crypto_verifyProof(lua_State *L) { free(v); } - lua_pushboolean(L, b); + lua_pushboolean(L, ret.r0); return 1; } @@ -147,7 +151,11 @@ static int crypto_keccak256(lua_State *L) { luaL_checktype(L, 1, LUA_TSTRING); arg = (char *) lua_tolstring(L, 1, &len); - ret = luaCryptoKeccak256(arg, len); + ret = luaCryptoKeccak256(L, arg, len); + if (ret.r2 != NULL) { + strPushAndRelease(L, ret.r2); + lua_error(L); + } lua_pushlstring(L, ret.r0, ret.r1); free(ret.r0); return 1; diff --git a/contract/crypto_module.h b/contract/vm/crypto_module.h similarity index 100% rename from contract/crypto_module.h rename to contract/vm/crypto_module.h diff --git a/contract/vm/db_module.c b/contract/vm/db_module.c new file mode 100644 index 000000000..ed2f71ee4 --- /dev/null +++ b/contract/vm/db_module.c @@ -0,0 +1,643 @@ +#include +#include +#include +#include +#include "vm.h" +//#include "sqlcheck.h" +#include "bignum_module.h" +#include "util.h" +#include "../db_msg.h" +#include "_cgo_export.h" + +#define RESOURCE_PSTMT_KEY "_RESOURCE_PSTMT_KEY_" +#define RESOURCE_RS_KEY "_RESOURCE_RS_KEY_" + +extern void checkLuaExecContext(lua_State *L); + +static int append_resource(lua_State *L, const char *key, void *data) { + int refno; + if (luaL_findtable(L, LUA_REGISTRYINDEX, key, 0) != NULL) { + luaL_error(L, "cannot find the environment of the db module"); + } + /* tab */ + lua_pushlightuserdata(L, data); /* tab pstmt */ + refno = luaL_ref(L, -2); /* tab */ + lua_pop(L, 1); /* remove tab */ + return refno; +} + +#define DB_PSTMT_ID "__db_pstmt__" + +typedef struct { + int id; + int closed; + int colcnt; + int refno; +} db_pstmt_t; + +#define DB_RS_ID "__db_rs__" + +typedef struct { + int query_id; + int closed; + int nc; + int refno; +} db_rs_t; + + +static void send_vm_api_request(lua_State *L, char *method, buffer *args, rresponse *resp) { + luaSendRequest(L, method, args, resp); +} + +static db_rs_t *get_db_rs(lua_State *L, int pos) { + db_rs_t *rs = luaL_checkudata(L, pos, DB_RS_ID); + if (rs->closed) { + luaL_error(L, "resultset is closed"); + } + return rs; +} + +static int db_rs_tostr(lua_State *L) { + db_rs_t *rs = luaL_checkudata(L, 1, DB_RS_ID); + if (rs->closed) { + lua_pushfstring(L, "resultset is closed"); + } else { + lua_pushfstring(L, "resultset{query_id=%d}", rs->query_id); + } + return 1; +} + +static int db_rs_get(lua_State *L) { + buffer buf = {0}, *req = &buf; + rresponse resp = {0}, *response = &resp; + db_rs_t *rs = get_db_rs(L, 1); + int count=0; + + add_int(req, rs->query_id); + + send_vm_api_request(L, "rsGet", req, response); + free_buffer(req); + if (response->error) { + lua_pushfstring(L, "get failed: %s", response->error); + free_response(response); + lua_error(L); + } + + char *ptr = NULL; + int len; + while (ptr = get_next_item(&response->result, ptr, &len)) { + char type = get_type(ptr, len); + ptr += 1; len -= 1; + switch (type) { + case 'b': + lua_pushboolean(L, read_bool(ptr)); + break; + case 'i': + lua_pushinteger(L, read_int(ptr)); + break; + case 'l': + lua_pushinteger(L, read_int64(ptr)); + break; + case 'd': + lua_pushnumber(L, read_double(ptr)); + break; + case 's': + lua_pushlstring(L, ptr, len - 1); + break; + case 'n': + lua_pushnil(L); + break; + default: + lua_pushnil(L); + } + count++; + } + + free_response(response); + return count; +} + +static int db_rs_colcnt(lua_State *L) { + db_rs_t *rs = get_db_rs(L, 1); + lua_pushinteger(L, rs->nc); + return 1; +} + +static void db_rs_close(lua_State *L, db_rs_t *rs, int remove) { + if (rs->closed) { + return; + } + rs->closed = 1; + if (remove) { + if (luaL_findtable(L, LUA_REGISTRYINDEX, RESOURCE_RS_KEY, 0) != NULL) { + luaL_error(L, "cannot find the environment of the db module"); + } + luaL_unref(L, -1, rs->refno); + lua_pop(L, 1); + } +} + +static int db_rs_next(lua_State *L) { + buffer buf = {0}, *req = &buf; + rresponse resp = {0}, *response = &resp; + db_rs_t *rs = get_db_rs(L, 1); + int rc; + + add_int(req, rs->query_id); + + send_vm_api_request(L, "rsNext", req, response); + free_buffer(req); + if (response->error) { + db_rs_close(L, rs, 1); + lua_pushfstring(L, "next failed: %s", response->error); + free_response(response); + lua_error(L); + } + + bool has_more = get_bool(&response->result, 1); + + if (has_more) { + lua_pushboolean(L, 1); + } else { + db_rs_close(L, rs, 1); + lua_pushboolean(L, 0); + } + + free_response(response); + return 1; +} + +static int db_rs_gc(lua_State *L) { + db_rs_close(L, luaL_checkudata(L, 1, DB_RS_ID), 1); + return 0; +} + +static db_pstmt_t *get_db_pstmt(lua_State *L, int pos) { + db_pstmt_t *pstmt = luaL_checkudata(L, pos, DB_PSTMT_ID); + if (pstmt->closed) { + luaL_error(L, "prepared statement is closed"); + } + return pstmt; +} + +static int db_pstmt_tostr(lua_State *L) { + db_pstmt_t *pstmt = luaL_checkudata(L, 1, DB_PSTMT_ID); + if (pstmt->closed) { + lua_pushfstring(L, "prepared statement is closed"); + } else { + lua_pushfstring(L, "prepared statement{id=%d}", pstmt->id); + } + return 1; +} + +static int add_parameters(lua_State *L, buffer *req) { + buffer buf = {0}, *params = &buf; + int rc, i; + int argc = lua_gettop(L) - 1; + + for (i = 1; i <= argc; i++) { + int t, b, n = i + 1; + const char *s; + size_t l; + + luaL_checkany(L, n); + t = lua_type(L, n); + + switch (t) { + case LUA_TNUMBER: + if (luaL_isinteger(L, n)) { + lua_Integer d = lua_tointeger(L, n); + add_int64(params, d); + } else { + lua_Number d = lua_tonumber(L, n); + add_double(params, d); + } + break; + case LUA_TSTRING: + s = lua_tolstring(L, n, &l); + add_string_ex(params, s, l); + break; + case LUA_TBOOLEAN: + b = lua_toboolean(L, n); + if (b) { + add_int(params, 1); + } else { + add_int(params, 0); + } + break; + case LUA_TNIL: + add_null(params); + break; + case LUA_TUSERDATA: + if (lua_isbignumber(L, n)) { + long int d = lua_get_bignum_si(L, n); + if (d == 0 && lua_bignum_is_zero(L, n) != 0) { + char *s = lua_get_bignum_str(L, n); + lua_pushfstring(L, "bignum value overflow for binding %s", s); + free(s); + free_buffer(params); + return -1; + } + add_int64(params, d); + break; + } + default: + lua_pushfstring(L, "unsupported type: %s", lua_typename(L, n)); + free_buffer(params); + return -1; + } + + if (rc != 0) { + lua_pushfstring(L, "add parameter failed"); + free_buffer(params); + return -1; + } + } + + add_bytes(req, params->ptr, params->len); + free_buffer(params); + return 0; +} + +static int db_pstmt_exec(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + int rc; + db_pstmt_t *pstmt = get_db_pstmt(L, 1); + + if (!pstmt || pstmt->id == 0) { + luaL_error(L, "invalid prepared statement"); + } + + add_int(args, pstmt->id); + + rc = add_parameters(L, args); + if (rc == -1) { + free_buffer(args); + lua_error(L); + } + + send_vm_api_request(L, "stmtExec", args, response); + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "exec failed: %s", response->error); + free_response(response); + lua_error(L); + } + + lua_Integer changes = get_int64(&response->result, 1); + lua_pushinteger(L, changes); + free_response(response); + return 1; +} + +static int db_pstmt_query(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + int rc; + db_pstmt_t *pstmt = get_db_pstmt(L, 1); + db_rs_t *rs; + + if (!pstmt || pstmt->id == 0) { + luaL_error(L, "invalid prepared statement"); + } + + checkLuaExecContext(L); + + add_int(args, pstmt->id); + + rc = add_parameters(L, args); + if (rc != 0) { + free_buffer(args); + lua_error(L); + } + + send_vm_api_request(L, "stmtQuery", args, response); + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "query failed: %s", response->error); + free_response(response); + lua_error(L); + } + + // store the query id on the structure + rs = (db_rs_t *) lua_newuserdata(L, sizeof(db_rs_t)); + luaL_getmetatable(L, DB_RS_ID); + lua_setmetatable(L, -2); + rs->query_id = get_int(&response->result, 1); + rs->nc = get_int(&response->result, 2); + rs->closed = 0; + rs->refno = append_resource(L, RESOURCE_RS_KEY, (void *)rs); + + free_response(response); + return 1; +} + +static void get_column_meta(lua_State *L, bytes *result) { + bytes names, types; + get_bytes(result, 1, &names); + get_bytes(result, 2, &types); + int colcnt = get_count(&names); + int i; + + lua_createtable(L, 0, 2); + lua_pushinteger(L, colcnt); + lua_setfield(L, -2, "colcnt"); + if (colcnt > 0) { + lua_createtable(L, colcnt, 0); /* colinfos names */ + lua_createtable(L, colcnt, 0); /* colinfos names decltypes */ + } else { + lua_pushnil(L); + lua_pushnil(L); + } + + for (i = 1; i <= colcnt; i++) { + char *name = get_string(&names, i); + lua_pushstring(L, name); + lua_rawseti(L, -3, i); + + char *decltype = get_string(&types, i); + lua_pushstring(L, decltype); + lua_rawseti(L, -2, i); + } + + lua_setfield(L, -3, "decltypes"); + lua_setfield(L, -2, "names"); +} + +static int db_pstmt_column_info(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + db_pstmt_t *pstmt = get_db_pstmt(L, 1); + + checkLuaExecContext(L); + + add_int(args, pstmt->id); + + send_vm_api_request(L, "stmtColumnInfo", args, response); + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "column_info failed: %s", response->error); + free_response(response); + lua_error(L); + } + + get_column_meta(L, &response->result); + free_response(response); + return 1; +} + +static int db_pstmt_bind_param_cnt(lua_State *L) { + db_pstmt_t *pstmt = get_db_pstmt(L, 1); + checkLuaExecContext(L); + lua_pushinteger(L, pstmt->colcnt); + return 1; +} + +static void db_pstmt_close(lua_State *L, db_pstmt_t *pstmt, int remove) { + if (pstmt->closed) { + return; + } + pstmt->closed = 1; + if (remove) { + if (luaL_findtable(L, LUA_REGISTRYINDEX, RESOURCE_PSTMT_KEY, 0) != NULL) { + luaL_error(L, "cannot find the environment of the db module"); + } + luaL_unref(L, -1, pstmt->refno); + lua_pop(L, 1); + } +} + +static int db_pstmt_gc(lua_State *L) { + db_pstmt_close(L, luaL_checkudata(L, 1, DB_PSTMT_ID), 1); + return 0; +} + +static int db_exec(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + const char *sql; + int rc; + + sql = luaL_checkstring(L, 1); + add_string(args, sql); + + rc = add_parameters(L, args); + if (rc == -1) { + free_buffer(args); + lua_error(L); + } + + send_vm_api_request(L, "dbExec", args, response); + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "exec failed: %s", response->error); + free_response(response); + lua_error(L); + } + + lua_Integer changes = get_int64(&response->result, 1); + lua_pushinteger(L, changes); + free_response(response); + return 1; +} + +static int db_query(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + db_rs_t *rs; + const char *sql; + int rc; + + checkLuaExecContext(L); + + sql = luaL_checkstring(L, 1); + add_string(args, sql); + + rc = add_parameters(L, args); + if (rc == -1) { + free_buffer(args); + lua_error(L); + } + + send_vm_api_request(L, "dbQuery", args, response); // it could release args + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "query failed: %s", response->error); + free_response(response); + lua_error(L); + } + + // store the query id on the structure + rs = (db_rs_t *) lua_newuserdata(L, sizeof(db_rs_t)); + luaL_getmetatable(L, DB_RS_ID); + lua_setmetatable(L, -2); + rs->query_id = get_int(&response->result, 1); + rs->nc = get_int(&response->result, 2); + rs->closed = 0; + rs->refno = append_resource(L, RESOURCE_RS_KEY, (void *)rs); + + free_response(response); + return 1; +} + +static int db_prepare(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + const char *sql; + db_pstmt_t *pstmt; + + checkLuaExecContext(L); + + sql = luaL_checkstring(L, 1); + add_string(args, sql); + + send_vm_api_request(L, "dbPrepare", args, response); + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "prepare failed: %s", response->error); + free_response(response); + lua_error(L); + } + + // save the prepared statement id on the structure + pstmt = (db_pstmt_t *) lua_newuserdata(L, sizeof(db_pstmt_t)); + luaL_getmetatable(L, DB_PSTMT_ID); + lua_setmetatable(L, -2); + pstmt->id = get_int(&response->result, 1); + pstmt->colcnt = get_int(&response->result, 2); + pstmt->closed = 0; + pstmt->refno = append_resource(L, RESOURCE_PSTMT_KEY, (void *)pstmt); + + return 1; +} + +static int db_get_snapshot(lua_State *L) { + rresponse resp = {0}, *response = &resp; + + checkLuaExecContext(L); + + send_vm_api_request(L, "dbGetSnapshot", NULL, response); + if (response->error) { + lua_pushfstring(L, "get_snapshot failed: %s", response->error); + free_response(response); + lua_error(L); + } + + lua_pushstring(L, get_string(&response->result, 1)); + free_response(response); + return 1; +} + +static int db_open_with_snapshot(lua_State *L) { + buffer buf = {0}, *args = &buf; + rresponse resp = {0}, *response = &resp; + char *snapshot = (char *) luaL_checkstring(L, 1); + + checkLuaExecContext(L); + + add_string(args, snapshot); + send_vm_api_request(L, "dbOpenWithSnapshot", args, response); + free_buffer(args); + if (response->error) { + lua_pushfstring(L, "open_with_snapshot failed: %s", response->error); + free_response(response); + lua_error(L); + } + free_response(response); + + return 0; +} + +static int db_last_insert_rowid(lua_State *L) { + rresponse resp = {0}, *response = &resp; + + checkLuaExecContext(L); + + send_vm_api_request(L, "lastInsertRowid", NULL, response); + if (response->error) { + lua_pushfstring(L, "last_insert_rowid failed: %s", response->error); + free_response(response); + lua_error(L); + } + + lua_Integer id = get_int64(&response->result, 1); + + lua_pushinteger(L, id); + return 1; +} + +int lua_db_release_resource(lua_State *L) { + lua_getfield(L, LUA_REGISTRYINDEX, RESOURCE_RS_KEY); + if (lua_istable(L, -1)) { + /* T */ + lua_pushnil(L); /* T nil(key) */ + while (lua_next(L, -2)) { + if (lua_islightuserdata(L, -1)) { + db_rs_close(L, (db_rs_t *) lua_topointer(L, -1), 0); + } + lua_pop(L, 1); + } + lua_pop(L, 1); + } + lua_getfield(L, LUA_REGISTRYINDEX, RESOURCE_PSTMT_KEY); + if (lua_istable(L, -1)) { + /* T */ + lua_pushnil(L); /* T nil(key) */ + while (lua_next(L, -2)) { + if (lua_islightuserdata(L, -1)) { + db_pstmt_close(L, (db_pstmt_t *) lua_topointer(L, -1), 0); + } + lua_pop(L, 1); + } + lua_pop(L, 1); + } + return 0; +} + +static const luaL_Reg rs_methods[] = { + {"next", db_rs_next}, + {"get", db_rs_get}, + {"colcnt", db_rs_colcnt}, + {"__tostring", db_rs_tostr}, + {"__gc", db_rs_gc}, + {NULL, NULL} +}; + +static const luaL_Reg pstmt_methods[] = { + {"exec", db_pstmt_exec}, + {"query", db_pstmt_query}, + {"column_info", db_pstmt_column_info}, + {"bind_param_cnt", db_pstmt_bind_param_cnt}, + {"__tostring", db_pstmt_tostr}, + {"__gc", db_pstmt_gc}, + {NULL, NULL} +}; + +static const luaL_Reg db_lib[] = { + {"exec", db_exec}, + {"query", db_query}, + {"prepare", db_prepare}, + {"getsnap", db_get_snapshot}, + {"open_with_snapshot", db_open_with_snapshot}, + {"last_insert_rowid", db_last_insert_rowid}, + {NULL, NULL} +}; + +int luaopen_db(lua_State *L) { + + luaL_newmetatable(L, DB_RS_ID); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + luaL_register(L, NULL, rs_methods); + + luaL_newmetatable(L, DB_PSTMT_ID); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + luaL_register(L, NULL, pstmt_methods); + + luaL_register(L, "db", db_lib); + + lua_pop(L, 3); + return 1; +} diff --git a/contract/vm/db_module.h b/contract/vm/db_module.h new file mode 100644 index 000000000..c8b425719 --- /dev/null +++ b/contract/vm/db_module.h @@ -0,0 +1,9 @@ +#ifndef _DB_MODULE_H +#define _DB_MODULE_H + +#include "lua.h" + +extern int luaopen_db(lua_State *L); +extern int lua_db_release_resource(lua_State *L); + +#endif /* _DB_MODULE_H */ diff --git a/contract/vm/db_msg_wrapper.c b/contract/vm/db_msg_wrapper.c new file mode 100644 index 000000000..3d396aa11 --- /dev/null +++ b/contract/vm/db_msg_wrapper.c @@ -0,0 +1 @@ +#include "../db_msg.c" diff --git a/contract/debug.c b/contract/vm/debug.c similarity index 100% rename from contract/debug.c rename to contract/vm/debug.c diff --git a/contract/debug.h b/contract/vm/debug.h similarity index 100% rename from contract/debug.h rename to contract/vm/debug.h diff --git a/contract/hook.go b/contract/vm/hook.go similarity index 65% rename from contract/hook.go rename to contract/vm/hook.go index 482b725ad..9190cc0c5 100644 --- a/contract/hook.go +++ b/contract/vm/hook.go @@ -1,7 +1,7 @@ //go:build !Debug // +build !Debug -package contract +package main /* #include "vm.h" @@ -9,14 +9,11 @@ package contract import "C" func (ce *executor) setCountHook(limit C.int) { - if ce == nil || - ce.L == nil || - ce.err != nil || - ce.ctx.IsGasSystem() { + if ce == nil || ce.L == nil || ce.err != nil || IsGasSystem() { C.vm_set_timeout_hook(ce.L) return } - if ce.ctx.blockInfo.ForkVersion >= 2 { + if hardforkVersion >= 2 { C.vm_set_timeout_count_hook(ce.L, limit) } else { C.vm_set_count_hook(ce.L, limit) diff --git a/contract/vm/main.go b/contract/vm/main.go new file mode 100644 index 000000000..994848dcf --- /dev/null +++ b/contract/vm/main.go @@ -0,0 +1,266 @@ +package main + +import ( + "fmt" + "strconv" + "encoding/binary" + "os" + "net" + "time" + "github.com/aergoio/aergo/v2/contract/msg" +) + +/* +#include "vm.h" +#include "_cgo_export.h" +*/ + +// global variables + +var hardforkVersion int +var isPubNet bool + +var secretKey string +var conn *net.UnixConn +var timedout bool + + +func main(){ + var socketName string + var err error + + args := os.Args + + // check if args are empty + if len(args) != 5 { + fmt.Println("Usage: aergovm ") + return + } + + // get the hardfork version from command line + hardforkVersion, err = strconv.Atoi(args[1]) + if err != nil { + fmt.Println("Error: Invalid hardfork version") + return + } + + // get PubNet from command line + isPubNet, err = strconv.ParseBool(args[2]) + if err != nil { + fmt.Println("Error: Invalid PubNet") + return + } + + // get socket name from command line + socketName = args[3] + if socketName == "" { + fmt.Println("Error: Invalid socket name") + return + } + + // get secret key from command line + secretKey = args[4] + if secretKey == "" { + fmt.Println("Error: Invalid secret key") + return + } + + // initialize Lua modules + InitializeVM() + + // connect to the server + err = connectToServer(socketName) + if err != nil { + fmt.Println("Error: Could not connect to server") + return + } + + // send ready message + sendReadyMessage() + + // wait for commands from the server + MessageLoop() + + // exit + closeApp(0) +} + +// connect to the server using an abstract unix domain socket (they start with a null byte) +func connectToServer(socketName string) (err error) { + rawConn, err := net.Dial("unix", "\x00"+socketName) + if err != nil { + return err + } + conn = rawConn.(*net.UnixConn) + return nil +} + +func sendReadyMessage() { + message := []byte("ready") + /*/ encrypt the message + message, err = msg.Encrypt(message, secretKey) + if err != nil { + fmt.Printf("Error: failed to encrypt message: %v\n", err) + return + } */ + msg.SendMessage(conn, message) +} + +func MessageLoop() { + + for { + // wait for command to execute, with null deadline + message, err := msg.WaitForMessage(conn, time.Time{}) + if err != nil { + fmt.Printf("Error: failed to receive message: %v\n", err) + return + } + /*/ decrypt the message + message, err = msg.Decrypt(message, secretKey) + if err != nil { + fmt.Printf("Error: failed to decrypt message: %v\n", err) + return + } */ + // deserialize the message + args, err := msg.DeserializeMessage(message) + if err != nil { + fmt.Printf("Error: failed to deserialize message: %v\n", err) + return + } + command := args[0] + args = args[1:] + fmt.Println("Received message: ", command, args) + // process the request + result, err := processCommand(command, args) + //if err != nil { + // return "", err + //} + // serialize the result and error + response := msg.SerializeMessage(result, err.Error()) + /*/ encrypt the response + response, err = msg.Encrypt(response, secretKey) + if err != nil { + fmt.Printf("Error: failed to encrypt message: %v\n", err) + return + } */ + // send the response + err = msg.SendMessage(conn, response) + if err != nil { + fmt.Printf("Error: failed to send message: %v\n", err) + return + } + } + +} + +func processCommand(command string, args []string) (string, error) { + + switch command { + case "execute": + if len(args) != 9 { + fmt.Println("execute: invalid number of arguments") + sendMessage([]string{"", "execute: invalid number of arguments"}) + closeApp(1) + } + address := args[0] + code := args[1] + fname := args[2] + fargs := args[3] + gasStr := args[4] + caller := args[5] + hasParent, err := strconv.ParseBool(args[6]) + if err != nil { + fmt.Println("execute: invalid hasParent argument") + sendMessage([]string{"", "execute: invalid hasParent argument"}) + closeApp(1) + } + isFeeDelegation, err := strconv.ParseBool(args[7]) + if err != nil { + fmt.Println("execute: invalid isFeeDelegation argument") + sendMessage([]string{"", "execute: invalid isFeeDelegation argument"}) + closeApp(1) + } + abiError := args[8] + + var gas uint64 + gasBytes := []byte(gasStr) + if len(gasBytes) != 8 { + fmt.Println("execute: invalid gas string length") + sendMessage([]string{"", "execute: invalid gas string length"}) + closeApp(1) + } + gas = binary.LittleEndian.Uint64(gasBytes) + + res, err, usedGas := Execute(address, code, fname, fargs, gas, caller, hasParent, isFeeDelegation, abiError) + + // encode the gas together with the result + gasBytes = make([]byte, 8) + binary.LittleEndian.PutUint64(gasBytes, usedGas) + res = string(gasBytes) + res + + var errStr string + if err != nil { + errStr = err.Error() + } + err = sendApiMessage("return", []string{res, errStr}) + if err != nil { + fmt.Printf("execute: failed to send message: %v\n", err) + closeApp(1) + } + closeApp(0) + + case "compile": + if len(args) != 2 { + fmt.Println("compile: invalid number of arguments") + sendMessage([]string{"", "compile: invalid number of arguments"}) + closeApp(1) + } + code := args[0] + hasParent, err := strconv.ParseBool(args[1]) + if err != nil { + fmt.Println("compile: invalid hasParent argument") + sendMessage([]string{"", "compile: invalid hasParent argument"}) + closeApp(1) + } + + bytecodeAbi, err := Compile(code, hasParent) + + var errStr string + if err != nil { + errStr = err.Error() + } + sendMessage([]string{string(bytecodeAbi), errStr}) + closeApp(0) + + // if the contract is executing, this can only be received if using another thread + // or if checking for incoming messages in regular intervals (expensive operation) + case "timeout": + timedout = true + return "", nil + + case "exit": + closeApp(0) + + } + + fmt.Println("aergovm: unknown command: ", command) + sendMessage([]string{"", "aergovm: unknown command: " + command}) + closeApp(1) + return "", nil +} + +func closeApp(ret int) { + + if conn != nil { + // ensure the data is sent before closing the connection + err := conn.CloseWrite() + if err != nil { + fmt.Printf("aergovm: failed to close write end of connection: %v\n", err) + os.Exit(1) + } + // close the connection + conn.Close() + } + + os.Exit(ret) +} diff --git a/contract/vm/main_test.go b/contract/vm/main_test.go new file mode 100644 index 000000000..26bcb7226 --- /dev/null +++ b/contract/vm/main_test.go @@ -0,0 +1,319 @@ +package main + +import ( + "encoding/binary" + "bytes" + "net" + "os" + "os/exec" + "testing" + "time" + + "github.com/aergoio/aergo/v2/cmd/aergoluac/util" + "github.com/aergoio/aergo/v2/contract/msg" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" +) + +var contractCode = ` +state.var { + kv = state.map() +} +function add(a, b) + return a + b +end +function set(key, value) + kv[key] = value +end +function get(key) + return kv[key] +end +function send(to, amount) + return contract.send(to, amount) +end +function call(...) + return contract.call(...) +end +function call_with_send(amount, ...) + return contract.call.value(amount)(...) +end +function delegatecall(...) + return contract.delegatecall(...) +end +function deploy(...) + return contract.deploy(...) +end +function deploy_with_send(amount, ...) + return contract.deploy.value(amount)(...) +end +function get_info() + return system.getContractID(), contract.balance(), system.getAmount(), system.getSender(), system.getOrigin(), system.isFeeDelegation() +end +function events() + contract.event('first', 123, 'abc') + contract.event('second', '456', 7.89) +end +abi.register(add, set, get, send, call, call_with_send, delegatecall, deploy, deploy_with_send, get_info, events) +` + +var initDone = false + +func compileVmExecutable(t *testing.T) { + + // Change the current working directory to the root directory + os.Chdir("../..") + + // Compile the VM binary + t.Log("Compiling the VM binary") + cmd := exec.Command("make", "aergovm") + err := cmd.Run() + require.NoError(t, err, "Failed to compile the VM executable") + +} + +func NewVmInstance(t *testing.T) (*exec.Cmd, net.Conn, chan struct{}) { + + if !initDone { + compileVmExecutable(t) + initDone = true + } + + // Set up the Unix domain socket + t.Log("Creating Unix domain socket") + socketName := "\x00test_socket" + rawListener, err := net.Listen("unix", socketName) + require.NoError(t, err, "Failed to create Unix domain socket") + defer rawListener.Close() + listener, ok := rawListener.(*net.UnixListener) + require.True(t, ok, "Failed to assign listener to *net.UnixListener") + + // Start the VM process + t.Log("Starting VM process") + vmCmd := exec.Command("bin/aergovm", "3", "false", socketName[1:], "test_secret_key") + + // Capture stdout and stderr + var stdout, stderr bytes.Buffer + vmCmd.Stdout = &stdout + vmCmd.Stderr = &stderr + err = vmCmd.Start() + require.NoError(t, err, "Failed to start VM process") + + done := make(chan struct{}) + + go func() { + err := vmCmd.Wait() + select { + case <-done: + // Test completed successfully, do nothing + default: + // VM process exited before test completion + if err != nil { + t.Errorf("VM process exited unexpectedly: %v", err) + t.Logf("Stderr: %s", stderr.String()) + t.Logf("Stdout: %s", stdout.String()) + t.FailNow() + } + } + }() + + // Wait for and accept the connection from VM with a timeout + t.Log("Waiting for connection from VM") + err = listener.SetDeadline(time.Now().Add(3 * time.Second)) + require.NoError(t, err, "Failed to set listener deadline") + conn, err := listener.AcceptUnix() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + t.Fatal("Timed out waiting for VM to connect") + } + require.NoError(t, err, "Failed to accept connection from VM") + } + + t.Log("Connection accepted from VM") + + // Reset the deadline + err = listener.SetDeadline(time.Time{}) + require.NoError(t, err, "Failed to reset listener deadline") + + t.Log("Waiting for ready message") + readyMsg, err := msg.WaitForMessage(conn, time.Now().Add(2*time.Second)) + require.NoError(t, err, "Failed to receive ready message") + require.Equal(t, "ready", string(readyMsg), "Unexpected ready message") + + return vmCmd, conn, done +} + +func TestVMExecutionBasicPlainCode(t *testing.T) { + + vmCmd, conn, done := NewVmInstance(t) + + // Test the execute command + t.Log("Sending execute command") + executeCmd := []string{"execute", "contract_address", contractCode, "add", `[123,456]`, "\x00\x00\x00\x01\x00\x00\x00\x00", "test_caller", "false", "false", ""} + serializedCmd := msg.SerializeMessage(executeCmd...) + err := msg.SendMessage(conn, serializedCmd) + require.NoError(t, err, "Failed to send execute command") + + t.Log("Waiting for execute response") + response, err := msg.WaitForMessage(conn, time.Now().Add(250*time.Millisecond)) + require.NoError(t, err, "Failed to receive execute response") + + // Deserialize the response + args, err := msg.DeserializeMessage(response) + require.NoError(t, err) + require.Len(t, args, 4, "Unexpected number of response arguments") + command := args[0] + result := args[1] + errStr := args[2] + inView := args[3] + + // Extract used gas and result + require.Greater(t, len(result), 8, "expected to contain encoded gas") + usedGas := binary.LittleEndian.Uint64([]byte(result[:8])) + result = result[8:] + + assert.Equal(t, "return", command) + assert.Equal(t, "579", result) + assert.Equal(t, "", errStr) + assert.Equal(t, uint64(3648), usedGas) + assert.Equal(t, "0", inView) + + vmCmd.Process.Kill() + conn.Close() + // Signal that the test is done + done <- struct{}{} + close(done) +} + +func TestVMCompileAndExecutionBasic(t *testing.T) { + + vmCmd, conn, done := NewVmInstance(t) + + // Compile the contract + t.Log("Sending compile command") + compileCmd := []string{"compile", contractCode, "false"} + serializedCmd := msg.SerializeMessage(compileCmd...) + err := msg.SendMessage(conn, serializedCmd) + require.NoError(t, err, "Failed to send compile command") + + t.Log("Waiting for compile response") + response, err := msg.WaitForMessage(conn, time.Now().Add(250*time.Millisecond)) + require.NoError(t, err, "Failed to receive compile response") + + args, err := msg.DeserializeMessage(response) + require.NoError(t, err) + require.Len(t, args, 2, "Unexpected number of response arguments") + bytecodeAbi := args[0] + errMsg := args[1] + require.Equal(t, "", errMsg) + + vmCmd.Process.Kill() + conn.Close() + // Signal that the test is done + done <- struct{}{} + close(done) + + + bytecode := util.LuaCode(bytecodeAbi).ByteCode() + + + vmCmd, conn, done = NewVmInstance(t) + + // Test the execute command + t.Log("Sending execute command") + executeCmd := []string{"execute", "contract_address", string(bytecode), "add", `[123,456]`, "\x00\x00\x00\x01\x00\x00\x00\x00", "test_caller", "false", "false", ""} + serializedCmd = msg.SerializeMessage(executeCmd...) + err = msg.SendMessage(conn, serializedCmd) + require.NoError(t, err, "Failed to send execute command") + + t.Log("Waiting for execute response") + response, err = msg.WaitForMessage(conn, time.Now().Add(250*time.Millisecond)) + require.NoError(t, err, "Failed to receive execute response") + + // Deserialize the response + args, err = msg.DeserializeMessage(response) + require.NoError(t, err) + require.Len(t, args, 4, "Unexpected number of response arguments") + command := args[0] + result := args[1] + errStr := args[2] + inView := args[3] + + // Extract used gas and result + require.GreaterOrEqual(t, len(result), 8, "expected to contain encoded gas") + usedGas := binary.LittleEndian.Uint64([]byte(result[:8])) + result = result[8:] + + assert.Equal(t, "return", command) + assert.Equal(t, "579", result) + assert.Equal(t, "", errStr) + assert.Equal(t, uint64(5856), usedGas) + assert.Equal(t, "0", inView) + + vmCmd.Process.Kill() + conn.Close() + // Signal that the test is done + done <- struct{}{} + close(done) +} + +func TestVMExecutionWithCallback(t *testing.T) { + + vmCmd, conn, done := NewVmInstance(t) + + // Test the execute command + t.Log("Sending execute command") + executeCmd := []string{"execute", "contract_address", contractCode, "send", `["test_to","9876543210"]`, "\x00\x00\x00\x01\x00\x00\x00\x00", "test_caller", "false", "false", ""} + serializedCmd := msg.SerializeMessage(executeCmd...) + err := msg.SendMessage(conn, serializedCmd) + require.NoError(t, err, "Failed to send execute command") + + t.Log("Waiting for execute response") + response, err := msg.WaitForMessage(conn, time.Now().Add(250*time.Millisecond)) + require.NoError(t, err, "Failed to receive execute response") + + // Deserialize the response + args, err := msg.DeserializeMessage(response) + require.NoError(t, err) + require.Len(t, args, 5) + require.Equal(t, "send", args[0]) + require.Equal(t, "test_to", args[1]) + require.Equal(t, "9876543210", args[2]) + require.Equal(t, "\x84\xf0\xff\x00\x00\x00\x00\x00", args[3]) + require.Equal(t, "0", args[4]) + + // Send response back to the VM instance + args = []string{"\x09\x00\x01\x00\x00\x00\x00\x00", ""} + message := msg.SerializeMessage(args...) + err = msg.SendMessage(conn, message) + require.NoError(t, err, "Failed to send response back to VM") + + t.Log("Waiting for execute response") + response, err = msg.WaitForMessage(conn, time.Now().Add(250*time.Millisecond)) + require.NoError(t, err, "Failed to receive execute response") + + // Deserialize the response + args, err = msg.DeserializeMessage(response) + require.NoError(t, err) + require.Len(t, args, 4, "Unexpected number of response arguments") + command := args[0] + result := args[1] + errStr := args[2] + inView := args[3] + + // Extract used gas and result + require.GreaterOrEqual(t, len(result), 8, "expected to contain encoded gas") + usedGas := binary.LittleEndian.Uint64([]byte(result[:8])) + result = result[8:] + + assert.Equal(t, "return", command) + assert.Equal(t, "", result) + assert.Equal(t, "", errStr) + assert.Equal(t, uint64(69509), usedGas) + assert.Equal(t, "0", inView) + + vmCmd.Process.Kill() + conn.Close() + // Signal that the test is done + done <- struct{}{} + close(done) +} diff --git a/contract/name_module.c b/contract/vm/name_module.c similarity index 53% rename from contract/name_module.c rename to contract/vm/name_module.c index 457b77d87..3788d03eb 100644 --- a/contract/name_module.c +++ b/contract/vm/name_module.c @@ -4,29 +4,23 @@ #include "util.h" #include "_cgo_export.h" -extern int getLuaExecContext(lua_State *L); +extern void checkLuaExecContext(lua_State *L); static int resolve(lua_State *L) { - char *name, *ret; - int service = getLuaExecContext(L); + struct luaNameResolve_return ret; + char *name; + checkLuaExecContext(L); lua_gasuse(L, 100); name = (char *)luaL_checkstring(L, 1); - ret = luaNameResolve(L, service, name); - if (ret == NULL) { - lua_pushnil(L); - } else { - // if the returned string starts with `[`, it's an error - if (ret[0] == '[') { - strPushAndRelease(L, ret); - luaL_throwerror(L); - } else { - strPushAndRelease(L, ret); - } + ret = luaNameResolve(L, name); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); } - + strPushAndRelease(L, ret.r0); return 1; } diff --git a/contract/name_module.h b/contract/vm/name_module.h similarity index 100% rename from contract/name_module.h rename to contract/vm/name_module.h diff --git a/contract/state_module.c b/contract/vm/state_module.c similarity index 100% rename from contract/state_module.c rename to contract/vm/state_module.c diff --git a/contract/state_module.h b/contract/vm/state_module.h similarity index 100% rename from contract/state_module.h rename to contract/vm/state_module.h diff --git a/contract/system_module.c b/contract/vm/system_module.c similarity index 80% rename from contract/system_module.c rename to contract/vm/system_module.c index 45f2b480f..0d2790c16 100644 --- a/contract/system_module.c +++ b/contract/vm/system_module.c @@ -8,18 +8,18 @@ #define STATE_DB_KEY_PREFIX "_" -extern int getLuaExecContext(lua_State *L); +extern void checkLuaExecContext(lua_State *L); static int systemPrint(lua_State *L) { char *jsonValue; - int service = getLuaExecContext(L); + checkLuaExecContext(L); lua_gasuse(L, 100); jsonValue = lua_util_get_json_from_stack(L, 1, lua_gettop(L), true); if (jsonValue == NULL) { luaL_throwerror(L); } - luaPrint(L, service, jsonValue); + luaPrint(L, jsonValue); free(jsonValue); return 0; } @@ -38,10 +38,10 @@ static char *getDbKey(lua_State *L, int *len) { int setItemWithPrefix(lua_State *L) { char *dbKey; char *jsonValue; - int service = getLuaExecContext(L); char *errStr; int keylen; + checkLuaExecContext(L); lua_gasuse(L, 100); luaL_checkstring(L, 1); @@ -57,12 +57,12 @@ int setItemWithPrefix(lua_State *L) { lua_gasuse_mul(L, GAS_SDATA, strlen(jsonValue)); - if ((errStr = luaSetDB(L, service, dbKey, keylen, jsonValue)) != NULL) { - free(jsonValue); + errStr = luaSetVariable(L, dbKey, keylen, jsonValue); + free(jsonValue); + if (errStr != NULL) { strPushAndRelease(L, errStr); luaL_throwerror(L); } - free(jsonValue); return 0; } @@ -75,12 +75,12 @@ int setItem(lua_State *L) { int getItemWithPrefix(lua_State *L) { char *dbKey; - int service = getLuaExecContext(L); char *jsonValue; char *blkno = NULL; - struct luaGetDB_return ret; + struct luaGetVariable_return ret; int keylen; + checkLuaExecContext(L); lua_gasuse(L, 100); luaL_checkstring(L, 1); @@ -99,7 +99,7 @@ int getItemWithPrefix(lua_State *L) { } dbKey = getDbKey(L, &keylen); - ret = luaGetDB(L, service, dbKey, keylen, blkno); + ret = luaGetVariable(L, dbKey, keylen, blkno); if (ret.r1 != NULL) { strPushAndRelease(L, ret.r1); luaL_throwerror(L); @@ -109,7 +109,7 @@ int getItemWithPrefix(lua_State *L) { } minus_inst_count(L, strlen(ret.r0)); - if (lua_util_json_to_lua(L, ret.r0, false) != 0) { + if (lua_util_json_value_to_lua(L, ret.r0, false) != 0) { strPushAndRelease(L, ret.r0); luaL_error(L, "getItem error : can't convert %s", lua_tostring(L, -1)); } @@ -129,17 +129,17 @@ int getItem(lua_State *L) { int delItemWithPrefix(lua_State *L) { char *dbKey; - int service = getLuaExecContext(L); char *jsonValue; char *ret; int keylen; + checkLuaExecContext(L); lua_gasuse(L, 100); luaL_checkstring(L, 1); luaL_checkstring(L, 2); dbKey = getDbKey(L, &keylen); - ret = luaDelDB(L, service, dbKey, keylen); + ret = luaDelVariable(L, dbKey, keylen); if (ret != NULL) { strPushAndRelease(L, ret); luaL_throwerror(L); @@ -148,64 +148,80 @@ int delItemWithPrefix(lua_State *L) { } static int getSender(lua_State *L) { - int service = getLuaExecContext(L); + checkLuaExecContext(L); char *sender; lua_gasuse(L, 1000); - sender = luaGetSender(L, service); + sender = luaGetSender(L); strPushAndRelease(L, sender); return 1; } static int getTxhash(lua_State *L) { - int service = getLuaExecContext(L); - char *hash; + struct luaGetTxHash_return ret; + checkLuaExecContext(L); lua_gasuse(L, 500); - hash = luaGetHash(L, service); - strPushAndRelease(L, hash); + ret = luaGetTxHash(L); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + strPushAndRelease(L, ret.r0); return 1; } static int getBlockHeight(lua_State *L) { - int service = getLuaExecContext(L); + struct luaGetBlockNo_return ret; + checkLuaExecContext(L); lua_gasuse(L, 300); - lua_pushinteger(L, luaGetBlockNo(L, service)); + ret = luaGetBlockNo(L); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + lua_pushinteger(L, ret.r0); return 1; } static int getTimestamp(lua_State *L) { - int service = getLuaExecContext(L); + struct luaGetTimeStamp_return ret; + checkLuaExecContext(L); lua_gasuse(L, 300); - lua_pushinteger(L, luaGetTimeStamp(L, service)); + ret = luaGetTimeStamp(L); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + lua_pushinteger(L, ret.r0); return 1; } static int getContractID(lua_State *L) { - int service = getLuaExecContext(L); + checkLuaExecContext(L); char *id; lua_gasuse(L, 1000); - id = luaGetContractId(L, service); + id = luaGetContractId(L); strPushAndRelease(L, id); return 1; } static int getCreator(lua_State *L) { - int service = getLuaExecContext(L); - struct luaGetDB_return ret; + checkLuaExecContext(L); + struct luaGetVariable_return ret; int keylen = 7; lua_gasuse(L, 500); - ret = luaGetDB(L, service, "Creator", keylen, 0); + ret = luaGetVariable(L, "Creator", keylen, 0); if (ret.r1 != NULL) { strPushAndRelease(L, ret.r1); luaL_throwerror(L); @@ -218,35 +234,47 @@ static int getCreator(lua_State *L) { } static int getAmount(lua_State *L) { - int service = getLuaExecContext(L); - char *amount; + struct luaGetAmount_return ret; + checkLuaExecContext(L); lua_gasuse(L, 300); - amount = luaGetAmount(L, service); - strPushAndRelease(L, amount); + ret = luaGetAmount(L); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + strPushAndRelease(L, ret.r0); return 1; } static int getOrigin(lua_State *L) { - int service = getLuaExecContext(L); - char *origin; + struct luaGetOrigin_return ret; + checkLuaExecContext(L); lua_gasuse(L, 1000); - origin = luaGetOrigin(L, service); - strPushAndRelease(L, origin); + ret = luaGetOrigin(L); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + strPushAndRelease(L, ret.r0); return 1; } static int getPrevBlockHash(lua_State *L) { - int service = getLuaExecContext(L); - char *hash; + struct luaGetPrevBlockHash_return ret; + checkLuaExecContext(L); lua_gasuse(L, 500); - hash = luaGetPrevBlockHash(L, service); - strPushAndRelease(L, hash); + ret = luaGetPrevBlockHash(L); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + strPushAndRelease(L, ret.r0); return 1; } /* datetime-related functions from lib_os.c. time(NULL) is replaced by blocktime(L) */ @@ -392,18 +420,20 @@ static int os_difftime(lua_State *L) { /* end of datetime functions */ static int lua_random(lua_State *L) { - int service = getLuaExecContext(L); - int min, max; + struct luaRandomInt_return ret; + int min, max, value; + + checkLuaExecContext(L); lua_gasuse(L, 100); switch (lua_gettop(L)) { case 1: + min = 1; max = luaL_checkint(L, 1); if (max < 1) { luaL_error(L, "system.random: the maximum value must be greater than zero"); } - lua_pushinteger(L, luaRandomInt(1, max, service)); break; case 2: min = luaL_checkint(L, 1); @@ -414,18 +444,26 @@ static int lua_random(lua_State *L) { if (min > max) { luaL_error(L, "system.random: the maximum value must be greater than the minimum value"); } - lua_pushinteger(L, luaRandomInt(min, max, service)); break; default: luaL_error(L, "system.random: 1 or 2 arguments required"); break; } + + ret = luaRandomInt(L, min, max); + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); + } + lua_pushinteger(L, ret.r0); return 1; } static int toPubkey(lua_State *L) { - char *address, *ret; + struct luaToPubkey_return ret; + char *address; + checkLuaExecContext(L); lua_gasuse(L, 100); // get the function argument @@ -433,23 +471,17 @@ static int toPubkey(lua_State *L) { // convert the address to public key ret = luaToPubkey(L, address); - if (ret == NULL) { - lua_pushnil(L); - } else { - // if the returned string starts with `[`, it's an error - if (ret[0] == '[') { - strPushAndRelease(L, ret); - luaL_throwerror(L); - } else { - strPushAndRelease(L, ret); - } + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); } - + strPushAndRelease(L, ret.r0); return 1; } static int toAddress(lua_State *L) { - char *pubkey, *ret; + struct luaToAddress_return ret; + char *pubkey; lua_gasuse(L, 100); @@ -458,30 +490,23 @@ static int toAddress(lua_State *L) { // convert the public key to an address ret = luaToAddress(L, pubkey); - if (ret == NULL) { - lua_pushnil(L); - } else { - // if the returned string starts with `[`, it's an error - if (ret[0] == '[') { - strPushAndRelease(L, ret); - luaL_throwerror(L); - } else { - strPushAndRelease(L, ret); - } + if (ret.r1 != NULL) { + strPushAndRelease(L, ret.r1); + luaL_throwerror(L); } - + strPushAndRelease(L, ret.r0); return 1; } static int is_contract(lua_State *L) { char *contract; - int service = getLuaExecContext(L); struct luaIsContract_return ret; + checkLuaExecContext(L); lua_gasuse(L, 100); contract = (char *)luaL_checkstring(L, 1); - ret = luaIsContract(L, service, contract); + ret = luaIsContract(L, contract); if (ret.r1 != NULL) { strPushAndRelease(L, ret.r1); luaL_throwerror(L); @@ -495,19 +520,8 @@ static int is_contract(lua_State *L) { } static int is_fee_delegation(lua_State *L) { - int service = getLuaExecContext(L); - struct luaIsFeeDelegation_return ret; - - ret = luaIsFeeDelegation(L, service); - if (ret.r1 != NULL) { - strPushAndRelease(L, ret.r1); - luaL_throwerror(L); - } - if (ret.r0 == 0) { - lua_pushboolean(L, false); - } else { - lua_pushboolean(L, true); - } + checkLuaExecContext(L); + lua_pushboolean(L, luaIsFeeDelegation()); return 1; } diff --git a/contract/system_module.h b/contract/vm/system_module.h similarity index 100% rename from contract/system_module.h rename to contract/vm/system_module.h diff --git a/contract/utf8_module.c b/contract/vm/utf8_module.c similarity index 100% rename from contract/utf8_module.c rename to contract/vm/utf8_module.c diff --git a/contract/util.c b/contract/vm/util.c similarity index 95% rename from contract/util.c rename to contract/vm/util.c index 56576e2a5..b0bf3611f 100644 --- a/contract/util.c +++ b/contract/vm/util.c @@ -353,6 +353,43 @@ static bool lua_util_dump_json(lua_State *L, int idx, sbuff_t *sbuf, bool json_f static int json_to_lua(lua_State *L, char **start, bool check, bool is_bignum); +// the input is like an array but without the [] characters +static int json_args_to_lua(lua_State *L, char *json, bool check) { + int count = 0; + while(*json != '\0') { + if (json_to_lua(L, &json, check, false) != 0) { + return -1; + } + if (*json == ',') { + ++json; + } else if(*json != '\0') { + return -1; + } + ++count; + } + return count; +} + +int lua_util_json_array_to_lua(lua_State *L, char *json, bool check) { + int count = 0; + if (*json != '[') { + return -1; + } + ++json; + while(*json != ']') { + if (json_to_lua(L, &json, check, false) != 0) { + return -1; + } + if (*json == ',') { + ++json; + } else if(*json != ']') { + return -1; + } + ++count; + } + return count; +} + static int json_array_to_lua_table(lua_State *L, char **start, bool check) { char *json = (*start) + 1; int index = 1; @@ -592,7 +629,7 @@ void minus_inst_count(lua_State *L, int count) { } } -int lua_util_json_to_lua(lua_State *L, char *json, bool check) { +int lua_util_json_value_to_lua(lua_State *L, char *json, bool check) { if (json_to_lua(L, &json, check, false) != 0) { return -1; } @@ -699,7 +736,7 @@ static int lua_json_decode(lua_State *L) { lua_gasuse(L, 50); minus_inst_count(L, strlen(json)); - if (lua_util_json_to_lua(L, json, true) != 0) { + if (lua_util_json_value_to_lua(L, json, true) != 0) { free(json); luaL_error(L, "not proper json format"); } diff --git a/contract/util.h b/contract/vm/util.h similarity index 80% rename from contract/util.h rename to contract/vm/util.h index 29fe9395a..bd62b9af0 100644 --- a/contract/util.h +++ b/contract/vm/util.h @@ -7,7 +7,9 @@ char *lua_util_get_json (lua_State *L, int idx, bool json_form); char *lua_util_get_json_from_stack (lua_State *L, int start, int end, bool json_form); char *lua_util_get_json_array_from_stack (lua_State *L, int start, int end, bool json_form); -int lua_util_json_to_lua (lua_State *L, char *json, bool check); +int lua_util_json_value_to_lua (lua_State *L, char *json, bool check); +int lua_util_json_array_to_lua(lua_State *L, char *json, bool check); + void minus_inst_count(lua_State *L, int count); int luaopen_json(lua_State *L); diff --git a/contract/vm.c b/contract/vm/vm.c similarity index 71% rename from contract/vm.c rename to contract/vm/vm.c index bbdc1aa9b..f8f231b70 100644 --- a/contract/vm.c +++ b/contract/vm/vm.c @@ -19,16 +19,14 @@ const char *VM_INST_COUNT = "__INST_COUNT_"; const int VM_TIMEOUT_INST_COUNT = 200; extern int luaopen_utf8(lua_State *L); + extern void (*lj_internal_view_start)(lua_State *); extern void (*lj_internal_view_end)(lua_State *); -void vm_internal_view_start(lua_State *L); -void vm_internal_view_end(lua_State *L); - -int getLuaExecContext(lua_State *L) { - int service = luaL_service(L); - if (service < 0) - luaL_error(L, "not permitted state referencing at global scope"); - return service; + +void checkLuaExecContext(lua_State *L) { + if (luaL_is_loading(L)) { + luaL_error(L, "state referencing not permitted at global scope"); + } } #ifdef MEASURE @@ -102,8 +100,7 @@ static void preloadModules(lua_State *L) { // used to rollback state and drop events upon error static int pcall(lua_State *L) { int argc = lua_gettop(L); - int service = getLuaExecContext(L); - int num_events = luaGetEventCount(L, service); + checkLuaExecContext(L); struct luaSetRecoveryPoint_return start_seq; int ret; @@ -113,7 +110,7 @@ static int pcall(lua_State *L) { lua_gasuse(L, 300); - start_seq = luaSetRecoveryPoint(L, service); + start_seq = luaSetRecoveryPoint(L); if (start_seq.r0 < 0) { strPushAndRelease(L, start_seq.r1); luaL_throwerror(L); @@ -125,10 +122,13 @@ static int pcall(lua_State *L) { // call the function ret = lua_pcall(L, argc - 1, LUA_MULTRET, 0); - // if failed, drop the events - if (ret != 0) { - if (vm_is_hardfork(L, 4)) { - luaDropEvent(L, service, num_events); + // release the recovery point (on success) or revert the contract state (on error) + if (start_seq.r0 > 0) { + bool is_error = (ret != 0); + char *errStr = luaClearRecovery(L, start_seq.r0, is_error); + if (errStr != NULL) { + strPushAndRelease(L, errStr); + luaL_throwerror(L); } } @@ -141,19 +141,6 @@ static int pcall(lua_State *L) { lua_pushboolean(L, ret == 0); lua_insert(L, 1); - // release the recovery point or revert the contract state - if (start_seq.r0 > 0) { - bool is_error = (ret != 0); - char *errStr = luaClearRecovery(L, service, start_seq.r0, is_error); - if (errStr != NULL) { - if (vm_is_hardfork(L, 4)) { - luaDropEvent(L, service, num_events); - } - strPushAndRelease(L, errStr); - luaL_throwerror(L); - } - } - // return the number of items in the stack return lua_gettop(L); } @@ -162,8 +149,7 @@ static int pcall(lua_State *L) { // used to rollback state and drop events upon error static int xpcall(lua_State *L) { int argc = lua_gettop(L); - int service = getLuaExecContext(L); - int num_events = luaGetEventCount(L, service); + checkLuaExecContext(L); struct luaSetRecoveryPoint_return start_seq; int ret, errfunc; @@ -173,7 +159,7 @@ static int xpcall(lua_State *L) { lua_gasuse(L, 300); - start_seq = luaSetRecoveryPoint(L, service); + start_seq = luaSetRecoveryPoint(L); if (start_seq.r0 < 0) { strPushAndRelease(L, start_seq.r1); luaL_throwerror(L); @@ -203,10 +189,13 @@ static int xpcall(lua_State *L) { // call the function ret = lua_pcall(L, argc - 2, LUA_MULTRET, errfunc); - // if failed, drop the events - if (ret != 0) { - if (vm_is_hardfork(L, 4)) { - luaDropEvent(L, service, num_events); + // release the recovery point (on success) or revert the contract state (on error) + if (start_seq.r0 > 0) { + bool is_error = (ret != 0); + char *errStr = luaClearRecovery(L, start_seq.r0, is_error); + if (errStr != NULL) { + strPushAndRelease(L, errStr); + luaL_throwerror(L); } } @@ -215,6 +204,7 @@ static int xpcall(lua_State *L) { luaL_throwerror(L); } +/* // ensure the stack has 1 free slot if (!lua_checkstack(L, 1)) { // return: false, "stack overflow" @@ -223,24 +213,12 @@ static int xpcall(lua_State *L) { lua_pushliteral(L, "stack overflow"); return 2; } +*/ // store the status at the bottom of the stack, replacing the error handler lua_pushboolean(L, ret == 0); lua_replace(L, 1); - // release the recovery point or revert the contract state - if (start_seq.r0 > 0) { - bool is_error = (ret != 0); - char *errStr = luaClearRecovery(L, service, start_seq.r0, is_error); - if (errStr != NULL) { - if (vm_is_hardfork(L, 4)) { - luaDropEvent(L, service, num_events); - } - strPushAndRelease(L, errStr); - luaL_throwerror(L); - } - } - // return the number of items in the stack return lua_gettop(L); } @@ -278,25 +256,14 @@ lua_State *vm_newstate(int hardfork_version) { return L; } -void vm_closestates(lua_State *s[], int count) { - int i; - - for (i = 0; i < count; ++i) - if (s[i] != NULL) - lua_close(s[i]); -} - -void initViewFunction() { - lj_internal_view_start = vm_internal_view_start; - lj_internal_view_end = vm_internal_view_end; -} - bool vm_is_hardfork(lua_State *L, int version) { int v = luaL_hardforkversion(L); return v >= version; } -const char *vm_loadcall(lua_State *L) { +// execute code from the global scope, like declaring state variables and functions +// as well as abi.register, abi.register_view, abi.payable, etc. +const char *vm_pre_run(lua_State *L) { int err; if (lua_usegas(L)) { @@ -308,17 +275,15 @@ const char *vm_loadcall(lua_State *L) { } else { vm_set_count_hook(L, 5000000); } - luaL_enablemaxmem(L); } err = lua_pcall(L, 0, 0, 0); if (lua_usegas(L)) { lua_disablegas(L); - } else { - luaL_disablemaxmem(L); } + // remove hook lua_sethook(L, NULL, 0, 0); if (err != 0) { @@ -327,26 +292,15 @@ const char *vm_loadcall(lua_State *L) { return NULL; } -static int cp_getLuaExecContext(lua_State *L) { - int *service = (int *)lua_topointer(L, 1); - *service = getLuaExecContext(L); - return 0; -} - -const char *vm_copy_service(lua_State *L, lua_State *main) { - int service; - service = luaL_service(main); - if (service < 0) { - return "not permitted state referencing at global scope"; - } - luaL_set_service(L, service); - return NULL; -} - -const char *vm_loadbuff(lua_State *L, const char *code, size_t sz, char *hex_id, int service) { +// load the code into the Lua state +const char *vm_load_code(lua_State *L, const char *code, size_t sz, char *hex_id) { int err; - luaL_set_service(L, service); + // enable check for memory limit + luaL_enablemaxmem(L); + // mark as running on global scope + luaL_set_loading(L, true); + err = luaL_loadbuffer(L, code, sz, hex_id); if (err != 0) { return lua_tostring(L, -1); @@ -355,7 +309,13 @@ const char *vm_loadbuff(lua_State *L, const char *code, size_t sz, char *hex_id, return NULL; } -int vm_autoload(lua_State *L, char *fname) { +void vm_push_abi_function(lua_State *L, char *fname) { + lua_getfield(L, LUA_GLOBALSINDEX, "abi"); + lua_getfield(L, -1, "call"); + lua_pushstring(L, fname); +} + +int vm_push_global_function(lua_State *L, char *fname) { lua_getfield(L, LUA_GLOBALSINDEX, fname); return lua_isnil(L, -1) == 0; } @@ -365,72 +325,21 @@ void vm_remove_constructor(lua_State *L) { lua_setfield(L, LUA_GLOBALSINDEX, construct_name); } -static void count_hook(lua_State *L, lua_Debug *ar) { - luaL_setuncatchablerror(L); - lua_pushstring(L, "exceeded the maximum instruction count"); - luaL_throwerror(L); -} - -void vm_set_count_hook(lua_State *L, int limit) { - lua_sethook(L, count_hook, LUA_MASKCOUNT, limit); -} - -static void timeout_hook(lua_State *L, lua_Debug *ar) { - int errCode = luaCheckTimeout(luaL_service(L)); - if (errCode == 1) { - luaL_setuncatchablerror(L); - lua_pushstring(L, ERR_BF_TIMEOUT); - luaL_throwerror(L); - } else if (errCode == -1) { - luaL_error(L, "cannot find execution context"); - } -} - -void vm_set_timeout_hook(lua_State *L) { - if (vm_is_hardfork(L, 2)) { - lua_sethook(L, timeout_hook, LUA_MASKCOUNT, VM_TIMEOUT_INST_COUNT); - } -} - -static void timeout_count_hook(lua_State *L, lua_Debug *ar) { - int errCode; - int inst_count, new_inst_count, inst_limit; - - timeout_hook(L, ar); - - inst_count = luaL_tminstcount(L); - inst_limit = luaL_tminstlimit(L); - new_inst_count = inst_count + VM_TIMEOUT_INST_COUNT; - if (new_inst_count <= 0 || new_inst_count > inst_limit) { - luaL_setuncatchablerror(L); - lua_pushstring(L, "exceeded the maximum instruction count"); - luaL_throwerror(L); - } - luaL_set_tminstcount(L, new_inst_count); -} - -void vm_set_timeout_count_hook(lua_State *L, int limit) { - luaL_set_tminstlimit(L, limit); - luaL_set_tminstcount(L, 0); - lua_sethook(L, timeout_count_hook, LUA_MASKCOUNT, VM_TIMEOUT_INST_COUNT); -} - -const char *vm_pcall(lua_State *L, int argc, int *nresult) { +const char *vm_call(lua_State *L, int argc, int *nresult) { int err; int nr = lua_gettop(L) - argc - 1; + // mark as no longer loading, now running the function call + luaL_set_loading(L, false); + if (lua_usegas(L)) { lua_enablegas(L); - } else { - luaL_enablemaxmem(L); } err = lua_pcall(L, argc, LUA_MULTRET, 0); if (lua_usegas(L)) { lua_disablegas(L); - } else { - luaL_disablemaxmem(L); } if (err != 0) { @@ -441,16 +350,17 @@ const char *vm_pcall(lua_State *L, int argc, int *nresult) { if (err != 0) { return lua_tostring(L, -1); } + *nresult = lua_gettop(L) - nr; return NULL; } -const char *vm_get_json_ret(lua_State *L, int nresult, int *err) { +const char *vm_get_json_ret(lua_State *L, int nresult, bool has_parent, int *err) { int top; char *json_ret; top = lua_gettop(L); - json_ret = lua_util_get_json_from_stack(L, top - nresult + 1, top, true); + json_ret = lua_util_get_json_from_stack(L, top - nresult + 1, top, !has_parent); if (json_ret == NULL) { *err = 1; @@ -463,69 +373,22 @@ const char *vm_get_json_ret(lua_State *L, int nresult, int *err) { return lua_tostring(L, -1); } -const char *vm_copy_result(lua_State *L, lua_State *target, int cnt) { - int i; - int top; - char *json; - - if (lua_usegas(L)) { - lua_disablegas(target); - } else { - luaL_disablemaxmem(target); - } - - top = lua_gettop(L); - for (i = top - cnt + 1; i <= top; ++i) { - json = lua_util_get_json(L, i, false); - if (json == NULL) { - if (lua_usegas(L)) { - lua_enablegas(target); - } else { - luaL_enablemaxmem(target); - } - return lua_tostring(L, -1); - } - - minus_inst_count(L, strlen(json)); - lua_util_json_to_lua(target, json, false); - free(json); - } - - if (lua_usegas(L)) { - lua_enablegas(target); - } else { - luaL_enablemaxmem(target); - } - - return NULL; -} - -sqlite3 *vm_get_db(lua_State *L) { - int service; - sqlite3 *db; +// VIEW FUNCTIONS - service = getLuaExecContext(L); - db = luaGetDbHandle(service); - if (db == NULL) { - lua_pushstring(L, "can't open a database connection"); - luaL_throwerror(L); - } - return db; +void vm_internal_view_start(lua_State *L) { + luaViewStart(); } -void vm_get_abi_function(lua_State *L, char *fname) { - lua_getfield(L, LUA_GLOBALSINDEX, "abi"); - lua_getfield(L, -1, "call"); - lua_pushstring(L, fname); +void vm_internal_view_end(lua_State *L) { + luaViewEnd(); } -void vm_internal_view_start(lua_State *L) { - luaViewStart(getLuaExecContext(L)); +void initViewFunction() { + lj_internal_view_start = vm_internal_view_start; + lj_internal_view_end = vm_internal_view_end; } -void vm_internal_view_end(lua_State *L) { - luaViewEnd(getLuaExecContext(L)); -} +// INSTRUCTION COUNT int vm_instcount(lua_State *L) { if (lua_usegas(L)) { @@ -548,3 +411,61 @@ void vm_setinstcount(lua_State *L, int count) { luaL_setinstcount(L, count); } } + +// INSTRUCTION COUNT + +// this function is called at every N instructions +static void count_hook(lua_State *L, lua_Debug *ar) { + luaL_setuncatchablerror(L); + lua_pushstring(L, "exceeded the maximum instruction count"); + luaL_throwerror(L); +} + +// set instruction count hook +void vm_set_count_hook(lua_State *L, int limit) { + lua_sethook(L, count_hook, LUA_MASKCOUNT, limit); +} + +// TIMEOUT + +// this function is called at every N instructions +static void timeout_hook(lua_State *L, lua_Debug *ar) { + int errCode = luaCheckTimeout(); + if (errCode == 1) { + luaL_setuncatchablerror(L); + lua_pushstring(L, ERR_BF_TIMEOUT); + luaL_throwerror(L); + } +} + +// set timeout hook +void vm_set_timeout_hook(lua_State *L) { + lua_sethook(L, timeout_hook, LUA_MASKCOUNT, VM_TIMEOUT_INST_COUNT); +} + +// timeout and instruction count hook +// this function is called at every N instructions +static void timeout_count_hook(lua_State *L, lua_Debug *ar) { + int inst_count, new_inst_count, inst_limit; + + // check for timeout + timeout_hook(L, ar); + + // check instruction count + inst_count = luaL_tminstcount(L); + inst_limit = luaL_tminstlimit(L); + new_inst_count = inst_count + VM_TIMEOUT_INST_COUNT; + if (new_inst_count <= 0 || new_inst_count > inst_limit) { + luaL_setuncatchablerror(L); + lua_pushstring(L, "exceeded the maximum instruction count"); + luaL_throwerror(L); + } + luaL_set_tminstcount(L, new_inst_count); +} + +// set timeout and instruction count hook +void vm_set_timeout_count_hook(lua_State *L, int limit) { + luaL_set_tminstlimit(L, limit); + luaL_set_tminstcount(L, 0); + lua_sethook(L, timeout_count_hook, LUA_MASKCOUNT, VM_TIMEOUT_INST_COUNT); +} diff --git a/contract/vm/vm.go b/contract/vm/vm.go new file mode 100644 index 000000000..302124c2e --- /dev/null +++ b/contract/vm/vm.go @@ -0,0 +1,524 @@ +package main + +/* + #cgo CFLAGS: -I${SRCDIR}/../../libtool/include/luajit-2.1 -I${SRCDIR}/../../libtool/include + #cgo !windows CFLAGS: -DLJ_TARGET_POSIX + #cgo darwin LDFLAGS: ${SRCDIR}/../../libtool/lib/libluajit-5.1.a ${SRCDIR}/../../libtool/lib/libgmp.dylib -lm + #cgo windows LDFLAGS: ${SRCDIR}/../../libtool/lib/libluajit-5.1.a ${SRCDIR}/../../libtool/bin/libgmp-10.dll -lm + #cgo !darwin,!windows LDFLAGS: ${SRCDIR}/../../libtool/lib/libluajit-5.1.a -L${SRCDIR}/../../libtool/lib64 -L${SRCDIR}/../../libtool/lib -lgmp -lm + + + #include + #include + #include "vm.h" + #include "util.h" + #include "bignum_module.h" +*/ +import "C" +import ( + "errors" + "fmt" + //"reflect" + //"sort" + "strings" + "unsafe" + "encoding/binary" + + "github.com/aergoio/aergo-lib/log" + "github.com/aergoio/aergo/v2/cmd/aergoluac/luac" +) + +const vmTimeoutErrMsg = "contract timeout during vm execution" + +var logger *log.Logger + +type LState = C.lua_State // C.struct_lua_State + +var lstate *LState // *C.lua_State + +var contractAddress string +var contractCaller string +var contractGasLimit uint64 +var contractIsFeeDelegation bool + +//////////////////////////////////////////////////////////////////////////////// + +func InitializeVM() { + if lstate == nil { + // these are called only once on tests + logger = log.NewLogger("vm") + C.init_bignum() + C.initViewFunction() + } else { + C.lua_close(lstate) + } + lstate = C.vm_newstate(C.int(hardforkVersion)) +} + +//////////////////////////////////////////////////////////////////////////////// + +type executor struct { + L *LState + code []byte + fname string + args string + numArgs C.int + isAutoload bool + jsonRet string + err error + abiErr error +} + +func newExecutor(bytecode []byte, fname string, args string, abiError string) *executor { + + ce := &executor{ + L: lstate, + code: bytecode, + } + + if abiError != "" { + ce.abiErr = errors.New(abiError) + } + + // set the gas limit on the Lua state + setGas() + + // load the contract code into the Lua state + ce.vmLoadCode() + if ce.err != nil { + return ce + } + + // if fname starts with "autoload:" then it is an autoload function + if strings.HasPrefix(fname, "autoload:") { + ce.isAutoload = true + fname = fname[9:] + } + ce.fname = fname + ce.args = args + + return ce +} + +//////////////////////////////////////////////////////////////////////////////// +// Lua +//////////////////////////////////////////////////////////////////////////////// + +// push the arguments to the stack +func (ce *executor) pushArguments() { + args := C.CString(ce.args) + ce.numArgs = C.lua_util_json_array_to_lua(ce.L, args, C.bool(true)); + C.free(unsafe.Pointer(args)) + if ce.numArgs == -1 { + ce.err = errors.New("invalid arguments. must be valid JSON array") + } +} + +/* +func (ce *executor) processArgs() { + for _, v := range ce.ci.Args { + if err := pushValue(ce.L, v); err != nil { + ce.err = err + return + } + } +} + +func pushValue(L *LState, v interface{}) error { + switch arg := v.(type) { + case string: + argC := C.CBytes([]byte(arg)) + C.lua_pushlstring(L, (*C.char)(argC), C.size_t(len(arg))) + C.free(argC) + case float64: + if arg == float64(int64(arg)) { + C.lua_pushinteger(L, C.lua_Integer(arg)) + } else { + C.lua_pushnumber(L, C.double(arg)) + } + case bool: + var b int + if arg { + b = 1 + } + C.lua_pushboolean(L, C.int(b)) + case json.Number: + str := arg.String() + intVal, err := arg.Int64() + if err == nil { + C.lua_pushinteger(L, C.lua_Integer(intVal)) + } else { + ftVal, err := arg.Float64() + if err != nil { + return errors.New("unsupported number type:" + str) + } + C.lua_pushnumber(L, C.double(ftVal)) + } + case nil: + C.lua_pushnil(L) + case []interface{}: + err := toLuaArray(L, arg) + if err != nil { + return err + } + case map[string]interface{}: + err := toLuaTable(L, arg) + if err != nil { + return err + } + default: + return errors.New("unsupported type:" + reflect.TypeOf(v).Name()) + } + return nil +} + +func toLuaArray(L *LState, arr []interface{}) error { + C.lua_createtable(L, C.int(len(arr)), C.int(0)) + n := C.lua_gettop(L) + for i, v := range arr { + if err := pushValue(L, v); err != nil { + return err + } + C.lua_rawseti(L, n, C.int(i+1)) + } + return nil +} + +func toLuaTable(L *LState, tab map[string]interface{}) error { + C.lua_createtable(L, C.int(0), C.int(len(tab))) + n := C.lua_gettop(L) + // get the keys and sort them + keys := make([]string, 0, len(tab)) + for k := range tab { + keys = append(keys, k) + } + if C.vm_is_hardfork(L, 3) { + sort.Strings(keys) + } + for _, k := range keys { + v := tab[k] + if len(tab) == 1 && strings.EqualFold(k, "_bignum") { + if arg, ok := v.(string); ok { + C.lua_settop(L, -2) + argC := C.CString(arg) + msg := C.lua_set_bignum(L, argC) + C.free(unsafe.Pointer(argC)) + if msg != nil { + return errors.New(C.GoString(msg)) + } + return nil + } + } + // push a key + key := C.CString(k) + C.lua_pushstring(L, key) + C.free(unsafe.Pointer(key)) + + if err := pushValue(L, v); err != nil { + return err + } + C.lua_rawset(L, n) + } + return nil +} +*/ + +//////////////////////////////////////////////////////////////////////////////// + +func (ce *executor) call(hasParent bool) { + + defer func() { + if ce.err != nil && hasParent { + if bool(C.luaL_hasuncatchablerror(ce.L)) { + ce.err = errors.New("uncatchable: " + ce.err.Error()) + } + if bool(C.luaL_hassyserror(ce.L)) { + ce.err = errors.New("syserror: " + ce.err.Error()) + } + } + }() + + if ce.err != nil { + return + } + + // execute code from the global scope, like declaring state variables and functions + // as well as abi.register, abi.register_view, abi.payable, etc. + ce.vmPreRun() + if ce.err != nil { + return + } + // if there is no error in the code pre-execution but failed to process the ABI, return the error now + if ce.abiErr != nil { + ce.err = ce.abiErr + return + } + + // push the function to be called to the stack + if ce.isAutoload { + // used for constructor and check_delegation functions + loaded := vmPushGlobalFunction(ce.L, ce.fname) + if !loaded { + if ce.fname == "constructor" { + // the constructor function was not found + if hasParent { + ce.jsonRet = "[]" + } + } else { + ce.err = errors.New(fmt.Sprintf("contract autoload failed %s function: %s", + contractAddress, ce.fname)) + } + return + } + } else { + // used for normal functions + vmPushAbiFunction(ce.L, ce.fname) + } + + // push the arguments to the stack + ce.pushArguments() + if ce.err != nil { + logger.Debug().Err(ce.err).Str("contract", contractAddress).Msg("invalid argument") + return + } + if !ce.isAutoload { + ce.numArgs = ce.numArgs + 1 + } + + // call the function + nRet := C.int(0) + cErrMsg := C.vm_call(ce.L, ce.numArgs, &nRet) + + // check for errors + if cErrMsg != nil { + errMsg := C.GoString(cErrMsg) + if (errMsg == C.ERR_BF_TIMEOUT || errMsg == vmTimeoutErrMsg) { + ce.err = errors.New(vmTimeoutErrMsg) // &VmTimeoutError{} + } else { + if bool(C.luaL_hassyserror(ce.L)) { + errMsg = "syserror: " + errMsg + } + if bool(C.luaL_hasuncatchablerror(ce.L)) { + errMsg = "uncatchable: " + errMsg + } + ce.err = errors.New(errMsg) + } + logger.Debug().Err(ce.err).Str("contract", contractAddress).Msg("contract execution failed") + return + } + + // convert the result to json + var errRet C.int + retMsg := C.GoString(C.vm_get_json_ret(ce.L, nRet, C.bool(hasParent), &errRet)) + if errRet == 1 { + ce.err = errors.New(retMsg) + } else { + ce.jsonRet = retMsg + } + +/*/ this can be moved to server side + if ce.ctx.traceFile != nil { + // write the contract code to a file in the temp directory + address := types.EncodeAddress(ce.contractId) + codeFile := fmt.Sprintf("%s%s%s.code", os.TempDir(), string(os.PathSeparator), address) + if _, err := os.Stat(codeFile); os.IsNotExist(err) { + f, err := os.OpenFile(codeFile, os.O_WRONLY|os.O_CREATE, 0644) + if err == nil { + _, _ = f.Write(ce.code) + _ = f.Close() + } + } + // write the used fee to the trace file + str := fmt.Sprintf("contract %s used fee: %s\n", address, ce.ctx.usedFee().String()) + _, _ = ce.ctx.traceFile.WriteString(str) + } +*/ + +} + +// push the function to be called to the stack +func vmPushGlobalFunction(L *LState, funcName string) bool { + fname := C.CString(funcName) + loaded := C.vm_push_global_function(L, fname) + C.free(unsafe.Pointer(fname)) + return loaded != C.int(0) +} + +// push the function to be called to the stack +func vmPushAbiFunction(L *LState, funcName string) { + C.vm_remove_constructor(L) + fname := C.CString(funcName) + C.vm_push_abi_function(L, fname) + C.free(unsafe.Pointer(fname)) +} + +// load the contract code +func (ce *executor) vmLoadCode() { + chunkId := C.CString("@" + contractAddress) + defer C.free(unsafe.Pointer(chunkId)) + + cErrMsg := C.vm_load_code( + ce.L, + (*C.char)(unsafe.Pointer(&ce.code[0])), + C.size_t(len(ce.code)), + chunkId, + ) + + if cErrMsg != nil { + errMsg := C.GoString(cErrMsg) + ce.err = errors.New(errMsg) + logger.Debug().Err(ce.err).Str("contract", contractAddress).Msg("failed to load code") + } +} + +// execute code from the global scope, like declaring state variables and functions +// as well as abi.register, abi.register_view, abi.payable, etc. +func (ce *executor) vmPreRun() { + cErrMsg := C.vm_pre_run(ce.L) + if cErrMsg != nil { + errMsg := C.GoString(cErrMsg) + isUncatchable := bool(C.luaL_hasuncatchablerror(ce.L)) + if isUncatchable && (errMsg == C.ERR_BF_TIMEOUT || errMsg == vmTimeoutErrMsg) { + ce.err = errors.New(vmTimeoutErrMsg) // &VmTimeoutError{} + } else { + ce.err = errors.New(errMsg) + } + } +} + + +//////////////////////////////////////////////////////////////////////////////// +// GAS +//////////////////////////////////////////////////////////////////////////////// + + +func IsGasSystem() bool { + return contractGasLimit > 0 +} + +func setGas() { + if IsGasSystem() { + C.lua_gasset(lstate, C.ulonglong(contractGasLimit)) + } +} + +func getRemainingGas() uint64 { + return uint64(C.lua_gasget(lstate)) +} + +func getUsedGas() uint64 { + return contractGasLimit - getRemainingGas() +} + +func addConsumedGas(gas uint64, err error) error { + if !IsGasSystem() { + return err + } + remainingGas := getRemainingGas() + if gas > remainingGas { + if err == nil { + err = errors.New("uncatchable: gas limit exceeded") + } + return err + } + remainingGas -= gas + C.lua_gasset(lstate, C.ulonglong(remainingGas)) + return err +} + +// extract the used gas from the result +func extractUsedGas(result string) (uint64, string) { + if len(result) < 8 { + return 0, result + } + usedGas := binary.LittleEndian.Uint64([]byte(result[:8])) + result = result[8:] + return usedGas, result +} + + +//////////////////////////////////////////////////////////////////////////////// + +/* +func setInstCount(ctx *vmContext, parent *LState, child *LState) { + if !IsGasSystem() { + C.vm_setinstcount(parent, C.vm_instcount(child)) + } +} + +func setInstMinusCount(ctx *vmContext, L *LState, deduc C.int) { + if !IsGasSystem() { + C.vm_setinstcount(L, minusCallCount(ctx, C.vm_instcount(L), deduc)) + } +} + +func minusCallCount(ctx *vmContext, curCount, deduc C.int) C.int { + if !IsGasSystem() { + return 0 + } + remain := curCount - deduc + if remain <= 0 { + remain = 1 + } + return remain +} +*/ + +//////////////////////////////////////////////////////////////////////////////// + +func Execute( + address string, + code string, + fname string, + args string, + gas uint64, + caller string, + hasParent bool, + isFeeDelegation bool, + abiError string, +) (string, error, uint64) { + + contractAddress = address + contractCaller = caller + contractGasLimit = gas + contractIsFeeDelegation = isFeeDelegation + + ex := newExecutor([]byte(code), fname, args, abiError) + + ex.call(hasParent) + + totalUsedGas := getUsedGas() + + return ex.jsonRet, ex.err, totalUsedGas +} + +//////////////////////////////////////////////////////////////////////////////// + +func Compile(code string, hasParent bool) ([]byte, error) { + L := luac.NewLState() + if L == nil { + return nil, errors.New("syserror: failed to create LState") + } + defer luac.CloseLState(L) + var lState = (*LState)(L) + + if hasParent { + // mark as running a call + C.luaL_set_loading(lState, C.bool(false)) + // set the hardfork version + //C.luaL_set_hardforkversion(lState, 5) + // set the timeout hook + C.vm_set_timeout_hook(lState) + } + + byteCodeAbi, err := luac.Compile(L, code) + if err != nil { + // if there is an uncatchable error, return it to the parent + if hasParent && bool(C.luaL_hasuncatchablerror(lState)) { + err = errors.New("uncatchable: " + err.Error()) + } + return nil, err + } + + return byteCodeAbi.Bytes(), nil +} diff --git a/contract/vm/vm.h b/contract/vm/vm.h new file mode 100644 index 000000000..7d6cb657c --- /dev/null +++ b/contract/vm/vm.h @@ -0,0 +1,31 @@ +#ifndef _VM_H +#define _VM_H + +#include +#include +#include +#include +#include + +extern const char *construct_name; + +#define ERR_BF_TIMEOUT "contract timeout" + +void initViewFunction(); +lua_State *vm_newstate(int hardfork_version); +void vm_push_abi_function(lua_State *L, char *fname); +int vm_push_global_function(lua_State *L, char *fname); +void vm_remove_constructor(lua_State *L); +const char *vm_load_code(lua_State *L, const char *code, size_t sz, char *hex_id); +const char *vm_pre_run(lua_State *L); +const char *vm_call(lua_State *L, int argc, int *nresult); +const char *vm_get_json_ret(lua_State *L, int nresult, bool has_parent, int *err); +void vm_set_count_hook(lua_State *L, int limit); +void vm_db_release_resource(lua_State *L); +bool vm_is_hardfork(lua_State *L, int version); +void vm_set_timeout_hook(lua_State *L); +void vm_set_timeout_count_hook(lua_State *L, int limit); +int vm_instcount(lua_State *L); +void vm_setinstcount(lua_State *L, int count); + +#endif /* _VM_H */ diff --git a/contract/vm/vm_callback.go b/contract/vm/vm_callback.go new file mode 100644 index 000000000..b9ef8b621 --- /dev/null +++ b/contract/vm/vm_callback.go @@ -0,0 +1,664 @@ +package main + +/* +#cgo CFLAGS: -I${SRCDIR}/../../libtool/include/luajit-2.1 +#cgo LDFLAGS: ${SRCDIR}/../../libtool/lib/libluajit-5.1.a -lm + +#include +#include +#include "vm.h" +#include "util.h" +#include "db_module.h" +#include "../db_msg.h" +#include "bignum_module.h" + +struct proof { + void *data; + size_t len; +}; + +#define RLP_TSTRING 0 +#define RLP_TLIST 1 + +struct rlp_obj { + int rlp_obj_type; + void *data; + size_t size; +}; +*/ +import "C" +import ( + "fmt" + "strconv" + "strings" + "unsafe" + "errors" + "time" + + "github.com/aergoio/aergo/v2/internal/enc/hex" + "github.com/aergoio/aergo/v2/contract/msg" +) + +var nestedView int + +//export luaViewStart +func luaViewStart() { + nestedView++ +} + +//export luaViewEnd +func luaViewEnd() { + nestedView-- +} + +//export luaSetVariable +func luaSetVariable(L *LState, key *C.char, keyLen C.int, value *C.char) *C.char { + args := []string{C.GoStringN(key, keyLen), C.GoString(value)} + _, err := sendRequest("set", args) + if err != nil { + return handleError(L, err) + } + return nil +} + +//export luaGetVariable +func luaGetVariable(L *LState, key *C.char, keyLen C.int, blkno *C.char) (*C.char, *C.char) { + args := []string{C.GoStringN(key, keyLen), C.GoString(blkno)} + result, err := sendRequest("get", args) + if err != nil { + return nil, handleError(L, err) + } + if len(result) > 0 { + return C.CString(result), nil + } + return nil, nil +} + +//export luaDelVariable +func luaDelVariable(L *LState, key *C.char, keyLen C.int) *C.char { + args := []string{C.GoStringN(key, keyLen)} + _, err := sendRequest("del", args) + if err != nil { + return handleError(L, err) + } + return nil +} + +//export luaCallContract +func luaCallContract(L *LState, + address *C.char, fname *C.char, arguments *C.char, + amount *C.char, gas uint64, +) (*C.char, *C.char) { + + contractAddress := C.GoString(address) + fnameStr := C.GoString(fname) + argsStr := C.GoString(arguments) + amountStr := C.GoString(amount) + gasLimit := getGasLimit(gas) + gasStr := string((*[8]byte)(unsafe.Pointer(&gasLimit))[:]) + + args := []string{contractAddress, fnameStr, argsStr, amountStr, gasStr} + result, err := sendRequest("call", args) + + // extract the used gas from the result + usedGas, result := extractUsedGas(result) + // update the remaining gas + err = addConsumedGas(usedGas, err) + + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaDelegateCallContract +func luaDelegateCallContract(L *LState, + address *C.char, fname *C.char, arguments *C.char, gas uint64, +) (*C.char, *C.char) { + + contractAddress := C.GoString(address) + fnameStr := C.GoString(fname) + argsStr := C.GoString(arguments) + gasLimit := getGasLimit(gas) + gasStr := string((*[8]byte)(unsafe.Pointer(&gasLimit))[:]) + + args := []string{contractAddress, fnameStr, argsStr, gasStr} + result, err := sendRequest("delegate-call", args) + + // extract the used gas from the result + usedGas, result := extractUsedGas(result) + // update the remaining gas + err = addConsumedGas(usedGas, err) + + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaSendAmount +func luaSendAmount(L *LState, address *C.char, amount *C.char) *C.char { + gasLimit := getGasLimit(0) + gasStr := string((*[8]byte)(unsafe.Pointer(&gasLimit))[:]) + args := []string{C.GoString(address), C.GoString(amount), gasStr} + result, err := sendRequest("send", args) + + // extract the used gas from the result + usedGas, result := extractUsedGas(result) + // update the remaining gas + err = addConsumedGas(usedGas, err) + + if err != nil { + return handleError(L, err) + } + // it does not return the result + return nil +} + +//export luaPrint +func luaPrint(L *LState, arguments *C.char) *C.char { + args := []string{C.GoString(arguments)} + _, err := sendRequest("print", args) + if err != nil { + return handleError(L, err) + } + return nil +} + +//export luaSetRecoveryPoint +func luaSetRecoveryPoint(L *LState) (C.int, *C.char) { + args := []string{} + result, err := sendRequest("setRecoveryPoint", args) + if err != nil { + return -1, handleError(L, err) + } + // if on a query or inside a view function + if result == "" { + return 0, nil + } + resultInt, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return -1, handleError(L, fmt.Errorf("uncatchable: luaSetRecoveryPoint: failed to parse result: %v", err)) + } + return C.int(resultInt), nil +} + +//export luaClearRecovery +func luaClearRecovery(L *LState, start int, isError bool) *C.char { + args := []string{fmt.Sprintf("%d", start), fmt.Sprintf("%t", isError)} + _, err := sendRequest("clearRecovery", args) + if err != nil { + return handleError(L, err) + } + return nil +} + +//export luaGetBalance +func luaGetBalance(L *LState, address *C.char) (*C.char, *C.char) { + args := []string{C.GoString(address)} + result, err := sendRequest("balance", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaGetSender +func luaGetSender(L *LState) *C.char { + return C.CString(contractCaller) +} + +//export luaGetTxHash +func luaGetTxHash(L *LState) (*C.char, *C.char) { + args := []string{} + result, err := sendRequest("getTxHash", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaGetBlockNo +func luaGetBlockNo(L *LState) (C.lua_Integer, *C.char) { + args := []string{} + result, err := sendRequest("getBlockNo", args) + if err != nil { + return C.lua_Integer(0), handleError(L, err) + } + blockNo, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return C.lua_Integer(0), handleError(L, fmt.Errorf("uncatchable: luaGetBlockNo: failed to parse result: %v", err)) + } + return C.lua_Integer(blockNo), nil +} + +//export luaGetTimeStamp +func luaGetTimeStamp(L *LState) (C.lua_Integer, *C.char) { + args := []string{} + result, err := sendRequest("getTimeStamp", args) + if err != nil { + return C.lua_Integer(0), handleError(L, err) + } + timestamp, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return C.lua_Integer(0), handleError(L, fmt.Errorf("uncatchable: failed to parse timestamp: %v", err)) + } + return C.lua_Integer(timestamp), nil +} + +//export luaGetContractId +func luaGetContractId(L *LState) *C.char { + return C.CString(contractAddress) +} + +//export luaGetAmount +func luaGetAmount(L *LState) (*C.char, *C.char) { + args := []string{} + result, err := sendRequest("getAmount", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaGetOrigin +func luaGetOrigin(L *LState) (*C.char, *C.char) { + args := []string{} + result, err := sendRequest("getOrigin", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaGetPrevBlockHash +func luaGetPrevBlockHash(L *LState) (*C.char, *C.char) { + args := []string{} + result, err := sendRequest("getPrevBlockHash", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +func checkHexString(data string) bool { + if len(data) >= 2 && data[0] == '0' && (data[1] == 'x' || data[1] == 'X') { + return true + } + return false +} + +//export luaCryptoSha256 +func luaCryptoSha256(L *LState, arg unsafe.Pointer, argLen C.int) (*C.char, *C.char) { + data := C.GoBytes(arg, argLen) + args := []string{string(data)} + result, err := sendRequest("sha256", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +func decodeHex(hexStr string) ([]byte, error) { + if checkHexString(hexStr) { + hexStr = hexStr[2:] + } + return hex.Decode(hexStr) +} + +//export luaECVerify +func luaECVerify(L *LState, msg *C.char, sig *C.char, addr *C.char) (C.int, *C.char) { + args := []string{C.GoString(msg), C.GoString(sig), C.GoString(addr)} + result, err := sendRequest("ecVerify", args) + if err != nil { + return C.int(-1), handleError(L, err) + } + resultInt, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return C.int(-1), handleError(L, fmt.Errorf("uncatchable: luaECVerify: failed to parse result: %v", err)) + } + return C.int(resultInt), nil +} + +func luaCryptoToBytes(data unsafe.Pointer, dataLen C.int) ([]byte, bool) { + var d []byte + b := C.GoBytes(data, dataLen) + isHex := checkHexString(string(b)) + if isHex { + var err error + d, err = hex.Decode(string(b[2:])) + if err != nil { + isHex = false + } + } + if !isHex { + d = b + } + return d, isHex +} + +func luaCryptoRlpToBytes(data unsafe.Pointer) []byte { + x := (*C.struct_rlp_obj)(data) + if x.rlp_obj_type == C.RLP_TSTRING { + b, _ := luaCryptoToBytes(x.data, C.int(x.size)) + // add a first byte to the byte array to indicate the type of the RLP object + b = append([]byte{byte(C.RLP_TSTRING)}, b...) + return b + } + elems := (*[1 << 30]C.struct_rlp_obj)(unsafe.Pointer(x.data))[:C.int(x.size):C.int(x.size)] + list := make([][]byte, len(elems)) + for i, elem := range elems { + b, _ := luaCryptoToBytes(elem.data, C.int(elem.size)) + list[i] = b + } + // serialize the list as a single byte array, including the type byte + ret := msg.SerializeMessageBytes(append([][]byte{[]byte{byte(C.RLP_TLIST)}}, list...)...) + return ret +} + +//export luaCryptoVerifyProof +func luaCryptoVerifyProof( + L *LState, + key unsafe.Pointer, keyLen C.int, + value unsafe.Pointer, + hash unsafe.Pointer, hashLen C.int, + proof unsafe.Pointer, nProof C.int, +) (C.int, *C.char) { + // convert to bytes + k, _ := luaCryptoToBytes(key, keyLen) + v := luaCryptoRlpToBytes(value) + h, _ := luaCryptoToBytes(hash, hashLen) + // read each proof element into a string array + cProof := (*[1 << 30]C.struct_proof)(proof)[:nProof:nProof] + proofElems := make([]string, int(nProof)) + for i, p := range cProof { + data, _ := luaCryptoToBytes(p.data, C.int(p.len)) + proofElems[i] = string(data) + } + // convert the proof elements into a single byte array + proofBytes := msg.SerializeMessage(proofElems...) + + // send request + args := []string{string(k), string(v), string(h), string(proofBytes)} + result, err := sendRequest("verifyEthStorageProof", args) + if err != nil { + return C.int(0), handleError(L, err) + } + resultInt, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return C.int(0), handleError(L, fmt.Errorf("uncatchable: luaCryptoVerifyProof: failed to parse result: %v", err)) + } + return C.int(resultInt), nil +} + +//export luaCryptoKeccak256 +func luaCryptoKeccak256(L *LState, data *C.char, dataLen C.int) (unsafe.Pointer, C.int, *C.char) { + args := []string{C.GoStringN(data, dataLen)} + result, err := sendRequest("keccak256", args) + if err != nil { + return nil, 0, handleError(L, err) + } + return C.CBytes([]byte(result)), C.int(len(result)), nil +} + +//export luaDeployContract +func luaDeployContract( + L *LState, + contract *C.char, + arguments *C.char, + amount *C.char, +) (*C.char, *C.char) { + + contractStr := C.GoString(contract) + argsStr := C.GoString(arguments) + amountStr := C.GoString(amount) + gasLimit := getGasLimit(0) + gasStr := string((*[8]byte)(unsafe.Pointer(&gasLimit))[:]) + + args := []string{contractStr, argsStr, amountStr, gasStr} + result, err := sendRequest("deploy", args) + + // extract the used gas from the result + usedGas, result := extractUsedGas(result) + // update the remaining gas + err = addConsumedGas(usedGas, err) + + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export isPublic +func isPublic() C.int { + if isPubNet { + return C.int(1) + } else { + return C.int(0) + } +} + +//export luaRandomInt +func luaRandomInt(L *LState, min, max C.int) (C.int, *C.char) { + args := []string{fmt.Sprintf("%d", min), fmt.Sprintf("%d", max)} + result, err := sendRequest("randomInt", args) + if err != nil { + return C.int(0), handleError(L, err) + } + value, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return C.int(0), handleError(L, fmt.Errorf("uncatchable: luaRandomInt: failed to parse result: %v", err)) + } + return C.int(value), nil +} + +//export luaEvent +func luaEvent(L *LState, eventName *C.char, arguments *C.char) *C.char { + args := []string{C.GoString(eventName), C.GoString(arguments)} + _, err := sendRequest("event", args) + if err != nil { + return handleError(L, err) + } + return nil +} + +//export luaToPubkey +func luaToPubkey(L *LState, address *C.char) (*C.char, *C.char) { + args := []string{C.GoString(address)} + result, err := sendRequest("toPubkey", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaToAddress +func luaToAddress(L *LState, pubkey *C.char) (*C.char, *C.char) { + args := []string{C.GoString(pubkey)} + result, err := sendRequest("toAddress", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaIsContract +func luaIsContract(L *LState, address *C.char) (C.int, *C.char) { + args := []string{C.GoString(address)} + result, err := sendRequest("isContract", args) + if err != nil { + return -1, handleError(L, err) + } + resultInt, err := strconv.ParseInt(result, 10, 64) + if err != nil { + return -1, handleError(L, fmt.Errorf("uncatchable: luaIsContract: failed to parse result: %v", err)) + } + return C.int(resultInt), nil +} + +//export luaNameResolve +func luaNameResolve(L *LState, name_or_address *C.char) (*C.char, *C.char) { + args := []string{C.GoString(name_or_address)} + result, err := sendRequest("nameResolve", args) + if err != nil { + return nil, handleError(L, err) + } + return C.CString(result), nil +} + +//export luaGovernance +func luaGovernance(L *LState, gType C.char, arg *C.char) *C.char { + args := []string{fmt.Sprintf("%c", gType), C.GoString(arg)} + _, err := sendRequest("governance", args) + if err != nil { + return handleError(L, err) + } + return nil +} + + + +// checks whether the block creation timeout occurred +// +//export luaCheckTimeout +func luaCheckTimeout() C.int { + if timedout { + return 1 + } + return 0 +} + +//export luaIsFeeDelegation +func luaIsFeeDelegation() C.bool { + return C.bool(contractIsFeeDelegation) +} + +//export luaGetStaking +func luaGetStaking(L *LState, addr *C.char) (*C.char, C.lua_Integer, *C.char) { + args := []string{C.GoString(addr)} + result, err := sendRequest("getStaking", args) + if err != nil { + return nil, 0, handleError(L, err) + } + // extract amount and when from result - result = staking.GetAmountBigInt().String() + "," + staking.When.String() + sep := strings.Index(result, ",") + amount := result[:sep] + when, err := strconv.ParseInt(result[sep+1:], 10, 64) + if err != nil { + return nil, 0, handleError(L, fmt.Errorf("uncatchable: luaGetStaking: failed to parse 'when': %v", err)) + } + return C.CString(amount), C.lua_Integer(when), nil +} + + +func getGasLimit(definedGasLimit uint64) uint64 { + + remainingGas := getRemainingGas() + + if definedGasLimit > 0 && definedGasLimit < remainingGas { + // if specified via contract.call.gas(limit)(...) + return definedGasLimit + } else { + // if not specified, use the remaining gas from the transaction + return remainingGas + } +} + +//export luaSendRequest +func luaSendRequest(L *LState, method *C.char, arguments *C.buffer, response *C.rresponse) { + var args []string + if arguments != nil { + args = []string{C.GoStringN(arguments.ptr, arguments.len)} + } else { + args = []string{} + } + result, err := sendRequest(C.GoString(method), args) + if err != nil { + response.error = handleError(L, err) + } else { + response.result.ptr = C.CString(result) + response.result.len = C.int(len(result)) + } +} + +var sendRequest = sendRequestFunc + +func sendRequestFunc(method string, args []string) (string, error) { + + // send the execution request to the VM instance + err := sendApiMessage(method, args) + if err != nil { + return "", err + } + + // wait for the response + response, err := msg.WaitForMessage(conn, time.Time{}) + if err != nil { + return "", err //FIXME: this is a system error + } + + /*/ decrypt the message + response, err = msg.Decrypt(response, secretKey) + if err != nil { + return "", err + } + */ + + list, err := msg.DeserializeMessage(response) + if err != nil { + return "", err + } + + result := list[0] + errstr := list[1] + if errstr != "" { + err = errors.New(errstr) + } + + // return the result + return result, err +} + +func sendApiMessage(method string, args []string) error { + + inViewStr := "0" + if nestedView > 0 { + inViewStr = "1" + } + + // create new slice with the method, args and whether it is within a view function + list := []string{method} + list = append(list, args...) + list = append(list, inViewStr) + + return sendMessage(list) +} + +func sendMessage(list []string) error { + + // build the message + message := msg.SerializeMessage(list...) + + /*/ encrypt the message + message, err = msg.Encrypt(message, secretKey) + if err != nil { + fmt.Printf("Error: failed to encrypt message: %v\n", err) + closeApp(1) + } */ + + // send the message to the VM API + return msg.SendMessage(conn, message) +} + +func handleError(L *LState, err error) *C.char { + errstr := err.Error() + if strings.HasPrefix(errstr, "uncatchable: ") { + errstr = errstr[len("uncatchable: "):] + C.luaL_setuncatchablerror(L) + } + if strings.HasPrefix(errstr, "syserror: ") { + errstr = errstr[len("syserror: "):] + C.luaL_setsyserror(L) + } + return C.CString(errstr) +} diff --git a/contract/vm/vm_test.go b/contract/vm/vm_test.go new file mode 100644 index 000000000..bee986864 --- /dev/null +++ b/contract/vm/vm_test.go @@ -0,0 +1,529 @@ +package main + +import ( + "testing" + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/aergoio/aergo/v2/cmd/aergoluac/util" +) + +func TestCompile(t *testing.T) { + + invalidCode := ` + function add(a, b) + return a + b + end + ` + + validCode := ` + function add(a, b) + return a + b + end + abi.register(add) + ` + + // Test case: Valid Lua code + byteCode, err := Compile(validCode, false) + assert.NoError(t, err, "Expected no error for valid Lua code") + assert.NotNil(t, byteCode, "Expected bytecode for valid Lua code") + + // Test case: Invalid Lua code + byteCode, err = Compile(invalidCode, false) + assert.Error(t, err, "Expected an error for invalid Lua code") + assert.Nil(t, byteCode, "Expected no bytecode for invalid Lua code") + + // Test case: Valid Lua code with parent + byteCode, err = Compile(validCode, true) + assert.NoError(t, err, "Expected no error for valid Lua code with parent") + assert.NotNil(t, byteCode, "Expected bytecode for valid Lua code with parent") + + // Test case: Invalid Lua code with parent + byteCode, err = Compile(invalidCode, true) + assert.Error(t, err, "Expected an error for invalid Lua code with parent") + assert.Nil(t, byteCode, "Expected no bytecode for invalid Lua code with parent") +} + +func TestExecuteBasic(t *testing.T) { + + contractCode := ` + function add(a, b) + return a + b + end + function hello(name) + return "Hello, " .. name + end + function many() + return 123, bignum.number(456), "abc", true, nil + end + function echo(...) + return ... + end + abi.register(add, hello, many, echo) + ` + + // set global variables + hardforkVersion = 3 + isPubNet = true + + // initialize the Lua VM + InitializeVM() + + // compile contract + byteCodeAbi, err := Compile(contractCode, false) + assert.NoError(t, err) + assert.NotNil(t, byteCodeAbi) + + bytecode := util.LuaCode(byteCodeAbi).ByteCode() + + // execute contract - add + result, err, usedGas := Execute("testAddress", string(bytecode), "add", `[1,2]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `3`, result) + + // execute contract - hello + result, err, usedGas = Execute("testAddress", string(bytecode), "hello", `["World"]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `"Hello, World"`, result) + + // execute contract - many + result, err, usedGas = Execute("testAddress", string(bytecode), "many", `[]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `[123,{"_bignum":"456"},"abc",true,null]`, result) + + // execute contract - echo + result, err, usedGas = Execute("testAddress", string(bytecode), "echo", `[123,4.56,{"_bignum":"789"},"abc",true,null]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `[123,4.56,{"_bignum":"789"},"abc",true,null]`, result) + +} + +func TestExecuteQueryBasic(t *testing.T) { + + contractCode := ` + function add(a, b) + return a + b + end + function hello(name) + return "Hello, " .. name + end + function many() + return 123, bignum.number(456), "abc", true, nil + end + function echo(...) + return ... + end + abi.register(add, hello, many, echo) + ` + + // set global variables + hardforkVersion = 3 + isPubNet = true + + // initialize the Lua VM + InitializeVM() + + // compile contract + byteCodeAbi, err := Compile(contractCode, false) + assert.NoError(t, err) + assert.NotNil(t, byteCodeAbi) + + bytecode := util.LuaCode(byteCodeAbi).ByteCode() + + // execute contract - add + result, err, usedGas := Execute("testAddress", string(bytecode), "add", `[1,2]`, 0, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Equal(t, usedGas, uint64(0), "Expected no gas to be used") + assert.Equal(t, `3`, result) + + // execute contract - hello + result, err, usedGas = Execute("testAddress", string(bytecode), "hello", `["World"]`, 0, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Equal(t, usedGas, uint64(0), "Expected no gas to be used") + assert.Equal(t, `"Hello, World"`, result) + + // execute contract - many + result, err, usedGas = Execute("testAddress", string(bytecode), "many", `[]`, 0, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Equal(t, usedGas, uint64(0), "Expected no gas to be used") + assert.Equal(t, `[123,{"_bignum":"456"},"abc",true,null]`, result) + + // execute contract - echo + result, err, usedGas = Execute("testAddress", string(bytecode), "echo", `[123,4.56,{"_bignum":"789"},"abc",true,null]`, 0, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Equal(t, usedGas, uint64(0), "Expected no gas to be used") + assert.Equal(t, `[123,4.56,{"_bignum":"789"},"abc",true,null]`, result) + +} + + +type vmCallback struct { + method string + args []string + result string + err error +} + +var callbacks []vmCallback + +func TestExecuteWithCallback(t *testing.T) { + + sendRequest = func(method string, args []string) (string, error) { + //fmt.Println("method: ", method, "args: ", args) + // get the next callback + callback := callbacks[0] + callbacks = callbacks[1:] + // check that the method and args are correct + assert.Equal(t, callback.method, method) + assert.Equal(t, callback.args, args) + return callback.result, callback.err + } + + contractCode := ` + state.var { + kv = state.map() + } + function set(key, value) + kv[key] = value + end + function get(key) + return kv[key] + end + function send(to, amount) + return contract.send(to, amount) + end + function call(...) + return contract.call(...) + end + function call_with_send(amount, ...) + return contract.call.value(amount)(...) + end + function delegatecall(...) + return contract.delegatecall(...) + end + function deploy(...) + return contract.deploy(...) + end + function deploy_with_send(amount, ...) + return contract.deploy.value(amount)(...) + end + function get_info() + return system.getContractID(), contract.balance(), system.getAmount(), system.getSender(), system.getOrigin(), system.isFeeDelegation() + end + function events() + contract.event('first', 123, 'abc') + contract.event('second', '456', 7.89) + end + abi.register(set, get, send, call, call_with_send, delegatecall, deploy, deploy_with_send, get_info, events) + ` + + contract2 := ` + state.var { + _owner = state.value(), + _name = state.value() + } + function default() + -- do nothing, only receive aergo + end + function constructor(first_name) + _name.set(first_name) + _owner.set(contract.getSender()) + end + abi.payable(constructor, default) + ` + + contract3 := ` + function sql_func() + local rt = {} + local rs = db.query("select round(3.14),min(1,2,3), max(4,5,6)") + if rs:next() then + local col1, col2, col3 = rs:get() + table.insert(rt, col1) + table.insert(rt, col2) + table.insert(rt, col3) + return rt + else + return "error in func()" + end + end + abi.register(sql_func) + ` + + // set global variables + hardforkVersion = 3 + isPubNet = true + + // initialize the Lua VM + InitializeVM() + + // compile contract + byteCodeAbi, err := Compile(contractCode, false) + assert.NoError(t, err) + assert.NotNil(t, byteCodeAbi) + + bytecode := util.LuaCode(byteCodeAbi).ByteCode() + + InitializeVM() + + // execute contract - set + callbacks = []vmCallback{ + {"get", []string{"_sv_meta-type_kv", ""}, "null", nil}, + {"set", []string{"_sv_meta-type_kv", "4"}, "", nil}, + {"set", []string{"_sv_kv-key", "12345"}, "", nil}, + } + result, err, usedGas := Execute("testAddress", string(bytecode), "set", `["key",12345]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, ``, result) + + InitializeVM() + + // execute contract - get + callbacks = []vmCallback{ + {"get", []string{"_sv_meta-type_kv", ""}, "4", nil}, + {"get", []string{"_sv_kv-key", ""}, `12345`, nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "get", `["key"]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `12345`, result) + + InitializeVM() + + // execute contract - send - simple + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"send", []string{"0x12345", "1000000000000000000", "\xa8*\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "send", `["0x12345","1000000000000000000"]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, ``, result) + + InitializeVM() + + // execute contract - send - with successful call + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"send", []string{"0x12345", "1000000000000000000", "\xa8*\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "send", `["0x12345","1000000000000000000"]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, ``, result) + + InitializeVM() + + // execute contract - send - with failed call + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"send", []string{"0x12345", "1000000000000000000", "\xa8*\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00", errors.New("failed call")}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "send", `["0x12345","1000000000000000000"]`, 1000000, "testCaller", false, false, "") + assert.Error(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + + InitializeVM() + + // execute contract - send - with invalid address + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"send", []string{"chucku-chucku", "1000000000000000000", "\xa8*\x0f\x00\x00\x00\x00\x00"}, "", errors.New("invalid address")}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "send", `["chucku-chucku","1000000000000000000"]`, 1000000, "testCaller", false, false, "") + assert.Error(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + + InitializeVM() + + // execute contract - send - with invalid amount + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"send", []string{"0x12345", "abc", "\xa8*\x0f\x00\x00\x00\x00\x00"}, "", errors.New("invalid amount")}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "send", `["0x12345","abc"]`, 1000000, "testCaller", false, false, "") + assert.Error(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + + InitializeVM() + + // execute contract - call + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"call", []string{"0x12345", "add", "[1,2]", "", ".#\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[3]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "call", `["0x12345","add",1,2]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `3`, result) + + InitializeVM() + + // execute contract - call with send + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"call", []string{"0x12345", "buy", `[1,"NFT"]`, "9876543210", "\xf8\x1d\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[\"purchased\",1,\"NFT\"]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "call_with_send", `["9876543210","0x12345","buy",1,"NFT"]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `["purchased",1,"NFT"]`, result) + + InitializeVM() + + // execute contract - call with send to default function + callbacks = []vmCallback{ + // the last argument is the gas in bytes, the first 8 bytes of the result is the used gas + {"send", []string{"0x12345", "9876543210", "\xca)\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "call_with_send", `["9876543210","0x12345"]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + + InitializeVM() + + // execute contract - delegated call + callbacks = []vmCallback{ + {"delegate-call", []string{"0x12345", "add", "[1,2]", "\x9d#\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[3]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "delegatecall", `["0x12345","add",1,2]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `3`, result) + + InitializeVM() + + // execute contract - deploy + callbacks = []vmCallback{ + {"deploy", []string{contractCode, "[]", "", "\xd5\x17\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "deploy", `["`+contractCode+`"]`, 1000005, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, ``, result) + + InitializeVM() + + // execute contract - deploy with invalid return + callbacks = []vmCallback{ + {"deploy", []string{contractCode, "[]", "", "\xd5\x17\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00...", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "deploy", `["`+contractCode+`"]`, 1000005, "testCaller", false, false, "") + assert.Error(t, err) + assert.Equal(t, err.Error(), "uncatchable: internal error: result from call is not a valid JSON array") + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + + InitializeVM() + + // execute contract - deploy with send + callbacks = []vmCallback{ + {"deploy", []string{contract2, "[250]", "9876543210", "|\x11\x0f\x00\x00\x00\x00\x00"}, "\x09\x00\x01\x00\x00\x00\x00\x00[]", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "deploy_with_send", `["9876543210","`+contract2+`",250]`, 1000005, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, ``, result) + + InitializeVM() + + // execute contract - get_info + callbacks = []vmCallback{ + {"balance", []string{""}, "123000", nil}, + {"getAmount", []string{}, "1000000", nil}, + {"getOrigin", []string{}, "anotherAddress", nil}, + {"isFeeDelegation", []string{}, "false", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "get_info", `[]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `["testAddress","123000","1000000","testCaller","anotherAddress",false]`, result) + + InitializeVM() + + // execute contract - events + callbacks = []vmCallback{ + {"event", []string{"first", `[123,"abc"]`}, "", nil}, + {"event", []string{"second", `["456",7.89]`}, "", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "events", `[]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.Empty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + + + + isPubNet = false + + InitializeVM() + + // compile contract + byteCodeAbi, err = Compile(contract3, false) + assert.NoError(t, err) + assert.NotNil(t, byteCodeAbi) + + bytecode = util.LuaCode(byteCodeAbi).ByteCode() + + InitializeVM() + + // execute contract - sql_func + callbacks = []vmCallback{ + {"dbQuery", []string{"+\x00\x00\x00sselect round(3.14),min(1,2,3), max(4,5,6)\x00\x01\x00\x00\x00y"}, "\x05\x00\x00\x00i\x01\x00\x00\x00", nil}, + {"rsNext", []string{"\x05\x00\x00\x00i\x01\x00\x00\x00"}, "\x02\x00\x00\x00b\x01", nil}, + {"rsGet", []string{"\x05\x00\x00\x00i\x01\x00\x00\x00"}, "\x05\x00\x00\x00i\x03\x00\x00\x00\x05\x00\x00\x00i\x01\x00\x00\x00\x05\x00\x00\x00i\x06\x00\x00\x00", nil}, + //{"rsNext", []string{"\x05\x00\x00\x00i\x01\x00\x00\x00"}, "\x02\x00\x00\x00b\x00", nil}, + } + result, err, usedGas = Execute("testAddress", string(bytecode), "sql_func", `[]`, 1000000, "testCaller", false, false, "") + assert.NoError(t, err) + assert.NotEmpty(t, result) + assert.Greater(t, usedGas, uint64(0), "Expected some gas to be used") + fmt.Println("used gas: ", usedGas) + assert.Equal(t, `[3,1,6]`, result) + +} diff --git a/contract/vm_api.go b/contract/vm_api.go new file mode 100644 index 000000000..d0157614b --- /dev/null +++ b/contract/vm_api.go @@ -0,0 +1,2146 @@ +package contract + +/* +#include +#include +#include +#include +#include "db_msg.h" +#include "db_module.h" + +#define ERR_BF_TIMEOUT "contract timeout" + +struct proof { + void *data; + size_t len; +}; + +#define RLP_TSTRING 0 +#define RLP_TLIST 1 + +struct rlp_obj { + int rlp_obj_type; + void *data; + size_t size; +}; +*/ +import "C" +import ( + "bytes" + "math/big" + "math/rand" + "crypto/sha256" + "errors" + "fmt" + "strconv" + "strings" + "unsafe" + "runtime" + + "github.com/aergoio/aergo-lib/log" + "github.com/aergoio/aergo/v2/cmd/aergoluac/util" + "github.com/aergoio/aergo/v2/contract/name" + "github.com/aergoio/aergo/v2/contract/system" + "github.com/aergoio/aergo/v2/contract/msg" + "github.com/aergoio/aergo/v2/internal/common" + "github.com/aergoio/aergo/v2/internal/enc/base58" + "github.com/aergoio/aergo/v2/internal/enc/hex" + "github.com/aergoio/aergo/v2/state" + "github.com/aergoio/aergo/v2/state/statedb" + "github.com/aergoio/aergo/v2/types" + "github.com/aergoio/aergo/v2/types/dbkey" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" +) + +var ( + mulAergo, mulGaer, zeroBig *big.Int + vmLogger = log.NewLogger("contract.vm") +) + +const ( + maxEventCntV2 = 50 + maxEventCntV4 = 128 + maxEventNameSize = 64 + maxEventArgSize = 4096 + luaCallCountDeduc = 1000 +) + +const ( + OLD_MSG = 0 + NEW_MSG = 1 +) + +func init() { + mulAergo = types.NewAmount(1, types.Aergo) + mulGaer = types.NewAmount(1, types.Gaer) + zeroBig = types.NewZeroAmount() +} + +func maxEventCnt(ctx *vmContext) int32 { + if ctx.blockInfo.ForkVersion >= 4 { + return maxEventCntV4 + } else { + return maxEventCntV2 + } +} + +func iif[T any](condition bool, trueVal, falseVal T) T { + if condition { + return trueVal + } + return falseVal +} + +//////////////////////////////////////////////////////////////////////////////// +// VM API +//////////////////////////////////////////////////////////////////////////////// + +func (ctx *vmContext) handleSetVariable(args []string) (result string, err error) { + if len(args) != 2 { + return "", errors.New("[System.SetVariable] invalid number of arguments") + } + key, value := []byte(args[0]), []byte(args[1]) + if ctx.isQuery || ctx.nestedView > 0 { + return "", errors.New("[System.SetVariable] set not permitted in query") + } + if err := ctx.curContract.callState.ctrState.SetData(key, value); err != nil { + return "", err + } + if err := ctx.addUpdateSize(int64(types.HashIDLength + len(value))); err != nil { + err = errors.New("uncatchable: " + err.Error()) + return "", err + } + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString("[Set]\n") + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("Key=%s Len=%v byte=%v\n", + string(key), len(key), key)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("Data=%s Len=%d byte=%v\n", + string(value), len(value), value)) + } + return "", nil +} + +func (ctx *vmContext) handleGetVariable(args []string) (result string, err error) { + if len(args) != 2 { + return "", errors.New("[System.GetVariable] invalid number of arguments") + } + key := []byte(args[0]) + blkno := args[1] + if len(blkno) > 0 { + bigNo, _ := new(big.Int).SetString(strings.TrimSpace(blkno), 10) + if bigNo == nil || bigNo.Sign() < 0 { + return "", errors.New("[System.GetVariable] invalid blockheight value :" + blkno) + } + blkNo := bigNo.Uint64() + + chainBlockHeight := ctx.blockInfo.No + if chainBlockHeight == 0 { + bestBlock, err := ctx.cdb.GetBestBlock() + if err != nil { + return "", errors.New("[System.GetVariable] get best block error") + } + chainBlockHeight = bestBlock.GetHeader().GetBlockNo() + } + if blkNo < chainBlockHeight { + blk, err := ctx.cdb.GetBlockByNo(blkNo) + if err != nil { + return "", err + } + accountId := types.ToAccountID(ctx.curContract.contractId) + contractProof, err := ctx.bs.GetAccountAndProof(accountId[:], blk.GetHeader().GetBlocksRootHash(), false) + if err != nil { + return "", errors.New("[System.GetVariable] failed to get snapshot state for account") + } else if contractProof.Inclusion { + trieKey := common.Hasher(key) + varProof, err := ctx.bs.GetVarAndProof(trieKey, contractProof.GetState().GetStorageRoot(), false) + if err != nil { + return "", errors.New("[System.GetVariable] failed to get snapshot state variable in contract") + } + if varProof.Inclusion { + if len(varProof.GetValue()) == 0 { + return "", nil + } + return string(varProof.GetValue()), nil + } + } + return "", nil + } + } + + data, err := ctx.curContract.callState.ctrState.GetData(key) + if err != nil { + return "", err + } + if data == nil { + return "", nil + } + return string(data), nil +} + +func (ctx *vmContext) handleDelVariable(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[System.DelVariable] invalid number of arguments") + } + key := []byte(args[0]) + if ctx.isQuery || ctx.nestedView > 0 { + return "", errors.New("[System.DelVariable] delete not permitted in query") + } + if err := ctx.curContract.callState.ctrState.DeleteData(key); err != nil { + return "", err + } + if err := ctx.addUpdateSize(int64(32)); err != nil { + err = errors.New("uncatchable: " + err.Error()) + return "", err + } + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString("[Del]\n") + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("Key=%s Len=%v byte=%v\n", + string(key), len(key), key)) + } + return "", nil +} + + +/* +func (ctx *vmContext) setInstCount(parent *LState, child *LState) { + if !ctx.IsGasSystem() { + C.vm_setinstcount(parent, C.vm_instcount(child)) + } +} + +func (ctx *vmContext) setInstMinusCount(L *LState, deduc C.int) { + if !ctx.IsGasSystem() { + C.vm_setinstcount(L, ctx.minusCallCount(C.vm_instcount(L), deduc)) + } +} + +func (ctx *vmContext) minusCallCount(curCount, deduc C.int) C.int { + if ctx.IsGasSystem() { + return 0 + } + remain := curCount - deduc + if remain <= 0 { + remain = 1 + } + return remain +} +*/ + + +func (ctx *vmContext) handleCall(args []string) (result string, err error) { + if len(args) != 5 { + return "", errors.New("[Contract.Call] invalid number of arguments") + } + contractAddress, fname, fargs, amount, gas := args[0], args[1], args[2], args[3], args[4] + // gas => remaining gas + // but it can also be the gas limit set by the caller contract + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + // get the contract address + cid, err := getAddressNameResolved(contractAddress, ctx.bs) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] invalid contractId: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] invalid contractId: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + aid := types.ToAccountID(cid) + + // read the amount for the contract call + amountBig, err := transformAmount(amount, ctx.blockInfo.ForkVersion) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] invalid amount: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] invalid amount: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the contract state + cs, err := getContractState(ctx, cid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] getAccount error: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] getAccount error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // check if the contract exists + bytecode := getContractCode(cs.ctrState, ctx.bs) + if bytecode == nil { + errmsg[NEW_MSG] = "[Contract.Call] cannot find contract " + contractAddress + errmsg[OLD_MSG] = "[Contract.LuaCallContract] cannot find contract " + contractAddress + return "", errors.New(errmsg[errnum]) + } + + // read the arguments for the contract call + var ci types.CallInfo + ci.Name = fname + err = getCallInfo(&ci.Args, []byte(fargs), cid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] invalid arguments: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] invalid arguments: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the remaining gas or gas limit from the parent contract + gasLimit, err := ctx.parseGasLimit(gas) + if err != nil { + return "", err + } + + // create a new executor + ce := newExecutor(bytecode, cid, ctx, &ci, amountBig, false, false, cs.ctrState) + defer ce.close() // close the executor and the VM instance + if ce.err != nil { + errmsg[NEW_MSG] = "[Contract.Call] newExecutor error: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] newExecutor error: " + return "", errors.New(errmsg[errnum] + ce.err.Error()) + } + + // set the remaining gas or gas limit from the parent contract + ce.contractGasLimit = gasLimit + + // send the amount to the contract + senderState := ctx.curContract.callState.accState + receiverState := cs.accState + if amountBig.Cmp(zeroBig) > 0 { + if ctx.isQuery == true || ctx.nestedView > 0 { + errmsg[NEW_MSG] = "[Contract.Call] send not permitted in query" + errmsg[OLD_MSG] = "[Contract.LuaCallContract] send not permitted in query" + return "", errors.New(errmsg[errnum]) + } + if r := sendBalance(senderState, receiverState, amountBig); r != nil { + errmsg[NEW_MSG] = "[Contract.Call] " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] " + return "", errors.New(errmsg[errnum] + r.Error()) + } + } + + seq, err := setRecoveryPoint(aid, ctx, senderState, cs, amountBig, false, false) + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[CALL Contract %v(%v) %v]\n", + contractAddress, aid.String(), fname)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("SendBalance: %s\n", amountBig.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", + senderState.Balance().String(), receiverState.Balance().String())) + } + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] database error: " + errmsg[OLD_MSG] = "[System.LuaCallContract] database error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // set the current contract info + prevContract := ctx.curContract + ctx.curContract = newContractInfo(cs, prevContract.contractId, cid, receiverState.RP(), amountBig) + defer func() { + ctx.curContract = prevContract + }() + + // execute the contract call + ce.call(false) + + // the result contains the used gas in the first 8 bytes + result = ce.jsonRet + + // check if the contract call failed + if ce.err != nil { + // revert the contract to the previous state + err := clearRecovery(ctx, seq, true) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] recovery err: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] recovery err: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) + } + // in case of timeout, return the original error message + switch ceErr := ce.err.(type) { + case *VmTimeoutError: + return result, errors.New(ceErr.Error()) + default: + errmsg[NEW_MSG] = "[Contract.Call] call err: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] call err: " + return "", errors.New(errmsg[errnum] + ceErr.Error()) + } + } + + // release the recovery point + if seq == 1 { + err := clearRecovery(ctx, seq, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Call] recovery err: " + errmsg[OLD_MSG] = "[Contract.LuaCallContract] recovery err: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + + // return the result + return result, nil +} + +func (ctx *vmContext) handleDelegateCall(args []string) (result string, err error) { + if len(args) != 4 { + return "", errors.New("[Contract.DelegateCall] invalid number of arguments") + } + contractAddress, fname, fargs, gas := args[0], args[1], args[2], args[3] + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + var isMultiCall bool + var cid []byte + + // get the contract address + if contractAddress == "multicall" { + isMultiCall = true + fargs = fname + fname = "execute" + cid = ctx.curContract.contractId + } else { + cid, err = getAddressNameResolved(contractAddress, ctx.bs) + if err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] invalid contractId: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] invalid contractId: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + aid := types.ToAccountID(cid) + + // get the contract state + var contractState *statedb.ContractState + if isMultiCall { + contractState = statedb.GetMultiCallState(cid, ctx.curContract.callState.ctrState.State) + } else { + contractState, err = getOnlyContractState(ctx, cid) + } + if err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] getContractState error: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract]getContractState error" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the contract code + var bytecode []byte + if isMultiCall { + bytecode = getMultiCallContractCode(contractState) + } else { + bytecode = getContractCode(contractState, ctx.bs) + } + if bytecode == nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] cannot find contract " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] cannot find contract " + return "", errors.New(errmsg[errnum] + contractAddress) + } + + // read the arguments for the contract call + var ci types.CallInfo + if isMultiCall { + err = getMultiCallInfo(&ci, []byte(fargs)) + } else { + ci.Name = fname + err = getCallInfo(&ci.Args, []byte(fargs), cid) + } + if err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] invalid arguments: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] invalid arguments: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the remaining gas or gas limit from the parent contract + gasLimit, err := ctx.parseGasLimit(gas) + if err != nil { + return "", err + } + + // create a new executor + ce := newExecutor(bytecode, cid, ctx, &ci, zeroBig, false, false, contractState) + defer ce.close() // close the executor and the VM instance + if ce.err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] newExecutor error: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] newExecutor error: " + return "", errors.New(errmsg[errnum] + ce.err.Error()) + } + + // set the remaining gas or gas limit from the parent contract + ce.contractGasLimit = gasLimit + + seq, err := setRecoveryPoint(aid, ctx, nil, ctx.curContract.callState, zeroBig, false, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] database error: " + errmsg[OLD_MSG] = "[System.LuaDelegateCallContract] database error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[DELEGATECALL Contract %v %v]\n", contractAddress, fname)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) + } + + // execute the contract call + ce.call(false) + + // the result contains the used gas in the first 8 bytes + result = ce.jsonRet + + // check if the contract call failed + if ce.err != nil { + // revert the contract to the previous state + err := clearRecovery(ctx, seq, true) + if err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] recovery error: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] recovery error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) + } + // in case of timeout, return the original error message + switch ceErr := ce.err.(type) { + case *VmTimeoutError: + return result, errors.New(ceErr.Error()) + default: + errmsg[NEW_MSG] = "[Contract.DelegateCall] call error: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] call error: " + return "", errors.New(errmsg[errnum] + ce.err.Error()) + } + } + + // release the recovery point + if seq == 1 { + err := clearRecovery(ctx, seq, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.DelegateCall] recovery error: " + errmsg[OLD_MSG] = "[Contract.LuaDelegateCallContract] recovery error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + + // return the result + return result, nil +} + +func getAddressNameResolved(account string, bs *state.BlockState) ([]byte, error) { + accountLen := len(account) + if accountLen == types.EncodedAddressLength { + return types.DecodeAddress(account) + } else if accountLen == types.NameLength { + cid, err := name.Resolve(bs, []byte(account), false) + if err != nil { + return nil, err + } + if cid == nil { + return nil, errors.New("name not founded :" + account) + } + return cid, nil + } + return nil, errors.New("invalid account length:" + account) +} + +func (ctx *vmContext) handleSend(args []string) (result string, err error) { + if len(args) != 3 { + return "", errors.New("[Contract.Send] invalid number of arguments") + } + contractAddress, amount, gas := args[0], args[1], args[2] + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + // read the amount to be sent + amountBig, err := transformAmount(amount, ctx.blockInfo.ForkVersion) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] invalid amount: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] invalid amount: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // cannot send amount in query + if (ctx.isQuery == true || ctx.nestedView > 0) && amountBig.Cmp(zeroBig) > 0 { + errmsg[NEW_MSG] = "[Contract.Send] send not permitted in query" + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] send not permitted in query" + return "", errors.New(errmsg[errnum]) + } + + // get the receiver account + cid, err := getAddressNameResolved(contractAddress, ctx.bs) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] invalid contractId: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] invalid contractId: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the receiver state + aid := types.ToAccountID(cid) + cs, err := getCallState(ctx, cid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] getAccount error: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] getAccount error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the sender state + senderState := ctx.curContract.callState.accState + receiverState := cs.accState + + // check if the receiver is a contract + if len(receiverState.CodeHash()) > 0 { + + // get the contract state + if cs.ctrState == nil { + cs.ctrState, err = statedb.OpenContractState(cid, receiverState.State(), ctx.bs.StateDB) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] getContractState error: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] getContractState error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + + // set the function to be called + var ci types.CallInfo + ci.Name = "default" + + // get the contract code + bytecode := getContractCode(cs.ctrState, ctx.bs) + if bytecode == nil { + errmsg[NEW_MSG] = "[Contract.Send] cannot find contract:" + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] cannot find contract:" + return "", errors.New(errmsg[errnum] + contractAddress) + } + + // get the remaining gas or gas limit from the parent contract + gasLimit, err := ctx.parseGasLimit(gas) + if err != nil { + return "", err + } + + // create a new executor + ce := newExecutor(bytecode, cid, ctx, &ci, amountBig, false, false, cs.ctrState) + defer ce.close() // close the executor and the VM instance + if ce.err != nil { + errmsg[NEW_MSG] = "[Contract.Send] newExecutor error: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] newExecutor error: " + return "", errors.New(errmsg[errnum] + ce.err.Error()) + } + + // set the remaining gas or gas limit from the parent contract + ce.contractGasLimit = gasLimit + + // send the amount to the contract + if amountBig.Cmp(zeroBig) > 0 { + if r := sendBalance(senderState, receiverState, amountBig); r != nil { + errmsg[NEW_MSG] = "[Contract.Send] " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] " + return "", errors.New(errmsg[errnum] + r.Error()) + } + } + + // create a recovery point + seq, err := setRecoveryPoint(aid, ctx, senderState, cs, amountBig, false, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] database error: " + errmsg[OLD_MSG] = "[System.LuaSendAmount] database error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString( + fmt.Sprintf("[Send Call default] %s(%s) : %s\n", types.EncodeAddress(cid), aid.String(), amountBig.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", + senderState.Balance().String(), receiverState.Balance().String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) + } + + // set the current contract info + prevContract := ctx.curContract + ctx.curContract = newContractInfo(cs, prevContract.contractId, cid, receiverState.RP(), amountBig) + defer func() { + ctx.curContract = prevContract + }() + + // execute the contract call + ce.call(false) + + // the result contains the used gas in the first 8 bytes + result = ce.jsonRet + + // check if the contract call failed + if ce.err != nil { + // revert the contract to the previous state + err := clearRecovery(ctx, seq, true) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] recovery err: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] recovery err: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) + } + // in case of timeout, return the original error message + switch ceErr := ce.err.(type) { + case *VmTimeoutError: + return result, errors.New(ceErr.Error()) + default: + errmsg[NEW_MSG] = "[Contract.Send] call err: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] call err: " + return "", errors.New(errmsg[errnum] + ce.err.Error()) + } + } + + // release the recovery point + if seq == 1 { + err := clearRecovery(ctx, seq, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Send] recovery err: " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] recovery err: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + + // the transfer and contract call succeeded + return result, nil + } + + // the receiver is not a contract, just send the amount + + // if amount is zero, do nothing + if amountBig.Cmp(zeroBig) == 0 { + return result, nil + } + + // send the amount to the receiver + if r := sendBalance(senderState, receiverState, amountBig); r != nil { + errmsg[NEW_MSG] = "[Contract.Send] " + errmsg[OLD_MSG] = "[Contract.LuaSendAmount] " + return "", errors.New(errmsg[errnum] + r.Error()) + } + + // update the recovery point + if ctx.lastRecoveryEntry != nil { + _, _ = setRecoveryPoint(aid, ctx, senderState, cs, amountBig, true, false) + } + + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[Send] %s(%s) : %s\n", + types.EncodeAddress(cid), aid.String(), amountBig.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", + senderState.Balance().String(), receiverState.Balance().String())) + } + + return result, nil +} + +func (ctx *vmContext) handlePrint(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.Print] invalid number of arguments") + } + ctrLgr.Info().Str("Contract SystemPrint", types.EncodeAddress(ctx.curContract.contractId)).Msg(args[0]) + return "", nil +} + +func (ctx *vmContext) handleSetRecoveryPoint() (result string, err error) { + if ctx.isQuery || ctx.nestedView > 0 { + return "", nil + } + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + curContract := ctx.curContract + // if it is the multicall code, ignore + if curContract.callState.ctrState.IsMultiCall() { + return "", nil + } + aid := types.ToAccountID(curContract.contractId) + seq, err := setRecoveryPoint(aid, ctx, nil, curContract.callState, zeroBig, false, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.SetRecoveryPoint] database error: " + errmsg[OLD_MSG] = "[Contract.pcall] database error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[pcall] snapshot set %d\n", seq)) + } + return strconv.Itoa(seq), nil +} + +func clearRecovery(ctx *vmContext, start int, revert bool) error { + item := ctx.lastRecoveryEntry + for { + if revert { + if item.revertState(ctx) != nil { + return errors.New("database error") + } + } + if item.seq == start { + if revert || item.prev == nil { + ctx.lastRecoveryEntry = item.prev + } + return nil + } + item = item.prev + if item == nil { + return errors.New("internal error") + } + } +} + +func (ctx *vmContext) handleClearRecovery(args []string) (result string, err error) { + if len(args) != 2 { + return "", errors.New("[Contract.ClearRecovery] invalid number of arguments") + } + start, err := strconv.Atoi(args[0]) + if err != nil { + return "", errors.New("[Contract.ClearRecovery] invalid start") + } + revert, err := strconv.ParseBool(args[1]) + if err != nil { + return "", errors.New("[Contract.ClearRecovery] invalid revert") + } + err = clearRecovery(ctx, start, revert) + if err != nil { + return "", err + } + if ctx.traceFile != nil && revert == true { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("pcall recovery snapshot: %d\n", start)) + } + return "", nil +} + +func (ctx *vmContext) handleGetBalance(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.GetBalance] invalid number of arguments") + } + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + contractAddress := args[0] + if contractAddress == "" { + return ctx.curContract.callState.ctrState.GetBalanceBigInt().String(), nil + } + cid, err := getAddressNameResolved(contractAddress, ctx.bs) + if err != nil { + errmsg[NEW_MSG] = "[Contract.GetBalance] invalid contractId: " + errmsg[OLD_MSG] = "[Contract.LuaGetBalance] invalid contractId: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + aid := types.ToAccountID(cid) + cs := ctx.callState[aid] + if cs == nil { + as, err := ctx.bs.GetAccountState(aid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.GetBalance] getAccount error: " + errmsg[OLD_MSG] = "[Contract.LuaGetBalance] getAccount error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + return as.GetBalanceBigInt().String(), nil + } + return cs.accState.Balance().String(), nil +} + + + +func (ctx *vmContext) getContractId() string { + return types.EncodeAddress(ctx.curContract.contractId) +} + +func (ctx *vmContext) getSender() string { + return types.EncodeAddress(ctx.curContract.sender) +} + +func (ctx *vmContext) getAmount() string { + return ctx.curContract.amount.String() +} + +func (ctx *vmContext) getTxHash() string { + return base58.Encode(ctx.txHash) +} + +func (ctx *vmContext) getOrigin() string { + return types.EncodeAddress(ctx.origin) +} + +func (ctx *vmContext) getIsFeeDelegation() bool { + return ctx.isFeeDelegation +} + +func (ctx *vmContext) getBlockNo() uint64 { + return ctx.blockInfo.No +} + +func (ctx *vmContext) getPrevBlockHash() string { + return base58.Encode(ctx.blockInfo.PrevBlockHash) +} + +func (ctx *vmContext) getTimestamp() uint64 { + return uint64(ctx.blockInfo.Ts / 1e9) +} + + + +func (ctx *vmContext) handleGetContractId() (result string, err error) { + //setInstMinusCount(ctx, L, 1000) + return types.EncodeAddress(ctx.curContract.contractId), nil +} + +func (ctx *vmContext) handleGetSender() (result string, err error) { + //setInstMinusCount(ctx, L, 1000) + return types.EncodeAddress(ctx.curContract.sender), nil +} + +func (ctx *vmContext) handleGetAmount() (result string, err error) { + return ctx.curContract.amount.String(), nil +} + +func (ctx *vmContext) handleGetTxHash() (result string, err error) { + return base58.Encode(ctx.txHash), nil +} + +func (ctx *vmContext) handleGetOrigin() (result string, err error) { + //setInstMinusCount(ctx, L, 1000) + return types.EncodeAddress(ctx.origin), nil +} + +func (ctx *vmContext) handleIsFeeDelegation() (result string, err error) { + if ctx.isFeeDelegation { + return "1", nil + } + return "0", nil +} + +func (ctx *vmContext) handleGetBlockNo() (result string, err error) { + return strconv.Itoa(int(ctx.blockInfo.No)), nil +} + +func (ctx *vmContext) handleGetPrevBlockHash() (result string, err error) { + return base58.Encode(ctx.blockInfo.PrevBlockHash), nil +} + +func (ctx *vmContext) handleGetTimeStamp() (result string, err error) { + return strconv.FormatInt(ctx.blockInfo.Ts / 1e9, 10), nil +} + + +//export checkDbExecContext +func checkDbExecContext(service C.int) bool { + // check if service is valid + if service < 0 || service >= C.int(len(contexts)) { + return false + } + if PubNet { + return false + } + return true +} + +//export luaGetDbHandle +func luaGetDbHandle(service C.int) *C.sqlite3 { + ctx := contexts[service] + curContract := ctx.curContract + cs := curContract.callState + if cs.tx != nil { + return cs.tx.getHandle() + } + var tx sqlTx + var err error + + // make sure that this go routine does not migrate to another thread + runtime.LockOSThread() + + aid := types.ToAccountID(curContract.contractId) + if ctx.isQuery == true { + tx, err = beginReadOnly(aid.String(), curContract.rp) + } else { + tx, err = beginTx(aid.String(), curContract.rp) + } + if err != nil { + sqlLgr.Error().Err(err).Msg("Begin SQL Transaction") + return nil + } + if ctx.isQuery == false { + err = tx.savepoint() + if err != nil { + sqlLgr.Error().Err(err).Msg("Begin SQL Transaction") + return nil + } + } + cs.tx = tx + return cs.tx.getHandle() +} + + +func checkHexString(data string) bool { + if len(data) >= 2 && data[0] == '0' && (data[1] == 'x' || data[1] == 'X') { + return true + } + return false +} + +func (ctx *vmContext) handleCryptoSha256(args []string) (string, error) { + if len(args) != 1 { + return "", fmt.Errorf("[Contract.CryptoSha256] invalid number of arguments") + } + data := []byte(args[0]) + if checkHexString(string(data)) { + dataStr := data[2:] + var err error + data, err = hex.Decode(string(dataStr)) + if err != nil { + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + errmsg[NEW_MSG] = "[Contract.CryptoSha256] hex decoding error: " + errmsg[OLD_MSG] = "[Contract.LuaCryptoSha256] hex decoding error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + h := sha256.New() + h.Write(data) + resultHash := h.Sum(nil) + return "0x" + hex.Encode(resultHash), nil +} + +func decodeHex(hexStr string) ([]byte, error) { + if checkHexString(hexStr) { + hexStr = hexStr[2:] + } + return hex.Decode(hexStr) +} + +func (ctx *vmContext) handleECVerify(args []string) (result string, err error) { + if len(args) != 3 { + return "", errors.New("[Contract.EcVerify] invalid number of arguments") + } + msg, sig, addr := args[0], args[1], args[2] + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + bMsg, err := decodeHex(msg) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] invalid message format: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] invalid message format: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + bSig, err := decodeHex(sig) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] invalid signature format: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] invalid signature format: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + var pubKey *btcec.PublicKey + var verifyResult bool + isAergo := len(addr) == types.EncodedAddressLength + + /*Aergo Address*/ + if isAergo { + bAddress, err := types.DecodeAddress(addr) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] invalid aergo address: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] invalid aergo address: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + pubKey, err = btcec.ParsePubKey(bAddress) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] error parsing pubKey: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] error parsing pubKey: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + + // CompactSign + if len(bSig) == 65 { + // ethereum + if !isAergo { + btcsig := make([]byte, 65) + btcsig[0] = bSig[64] + 27 + copy(btcsig[1:], bSig) + bSig = btcsig + } + pub, _, err := ecdsa.RecoverCompact(bSig, bMsg) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] error recoverCompact: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] error recoverCompact: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + if pubKey != nil { + verifyResult = pubKey.IsEqual(pub) + } else { + bAddress, err := decodeHex(addr) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] invalid Ethereum address: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] invalid Ethereum address: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + bPub := pub.SerializeUncompressed() + h := sha256.New() + h.Write(bPub[1:]) + signAddress := h.Sum(nil)[12:] + verifyResult = bytes.Equal(bAddress, signAddress) + } + } else { + sign, err := ecdsa.ParseSignature(bSig) + if err != nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] error parsing signature: " + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] error parsing signature: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + if pubKey == nil { + errmsg[NEW_MSG] = "[Contract.EcVerify] error recovering pubKey" + errmsg[OLD_MSG] = "[Contract.LuaEcVerify] error recovering pubKey" + return "", errors.New(errmsg[errnum]) + } + verifyResult = sign.Verify(bMsg, pubKey) + } + if verifyResult { + return "1", nil + } + return "0", nil +} + +func luaCryptoToBytes(data []byte) ([]byte, bool) { + var d []byte + isHex := checkHexString(string(data)) + if isHex { + var err error + d, err = hex.Decode(string(data[2:])) + if err != nil { + isHex = false + } + } + if !isHex { + d = data + } + return d, isHex +} + +func cryptoBytesToRlpObject(data []byte) rlpObject { + // read the first byte to determine the type of the RLP object + rlpType := data[0] + data = data[1:] + // convert the remaining bytes to the appropriate type + if rlpType == C.RLP_TSTRING { + return rlpString(data) + } + // if the type is not a list, return nil + if rlpType != C.RLP_TLIST { + return nil + } + // the type is a list. deserialize it + items, err := msg.DeserializeMessage(data) + if err != nil { + return nil + } + // convert the items to rlpList + list := make(rlpList, len(items)) + for i, item := range items { + list[i] = rlpString(item) + } + return list +} + +func (ctx *vmContext) handleCryptoVerifyEthStorageProof(args []string) (result string, err error) { + if len(args) != 4 { + return "", errors.New("[Contract.CryptoVerifyEthStorageProof] invalid number of arguments") + } + key := []byte(args[0]) + value := cryptoBytesToRlpObject([]byte(args[1])) + hash := []byte(args[2]) + proof, err := msg.DeserializeMessage([]byte(args[3])) + if err != nil { + return "", errors.New("[Contract.CryptoVerifyEthStorageProof] error deserializing proof: " + err.Error()) + } + proofBytes := make([][]byte, len(proof)) + for i, p := range proof { + proofBytes[i] = []byte(p) + } + + if verifyEthStorageProof(key, value, hash, proofBytes) { + return "1", nil + } + return "0", nil +} + +func (ctx *vmContext) handleCryptoKeccak256(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.CryptoKeccak256] invalid number of arguments") + } + data, isHex := luaCryptoToBytes([]byte(args[0])) + h := keccak256(data) + if isHex { + hexb := "0x" + hex.Encode(h) + return hexb, nil + } else { + return string(h), nil + } +} + +// transformAmount processes the input string to calculate the total amount, +// taking into account the different units ("aergo", "gaer", "aer") +func transformAmount(amountStr string, forkVersion int32) (*big.Int, error) { + if len(amountStr) == 0 { + return zeroBig, nil + } + + if forkVersion >= 4 { + // Check for amount in decimal format + if strings.Contains(amountStr,".") && strings.HasSuffix(strings.ToLower(amountStr),"aergo") { + // Extract the part before the unit + decimalAmount := amountStr[:len(amountStr)-5] + decimalAmount = strings.TrimRight(decimalAmount, " ") + // Parse the decimal amount + decimalAmount = parseDecimalAmount(decimalAmount, 18) + if decimalAmount == "error" { + return nil, errors.New("converting error for BigNum: " + amountStr) + } + amount, valid := new(big.Int).SetString(decimalAmount, 10) + if !valid { + return nil, errors.New("converting error for BigNum: " + amountStr) + } + return amount, nil + } + } + + totalAmount := new(big.Int) + remainingStr := amountStr + + // Define the units and corresponding multipliers + for _, data := range []struct { + unit string + multiplier *big.Int + }{ + {"aergo", mulAergo}, + {"gaer", mulGaer}, + {"aer", zeroBig}, + } { + idx := strings.Index(strings.ToLower(remainingStr), data.unit) + if idx != -1 { + // Extract the part before the unit + subStr := remainingStr[:idx] + + // Parse and convert the amount + partialAmount, err := parseAndConvert(subStr, data.unit, data.multiplier, amountStr) + if err != nil { + return nil, err + } + + // Add to the total amount + totalAmount.Add(totalAmount, partialAmount) + + // Adjust the remaining string to process + remainingStr = remainingStr[idx+len(data.unit):] + } + } + + // Process the rest of the string, if there is some + if len(remainingStr) > 0 { + partialAmount, err := parseAndConvert(remainingStr, "", zeroBig, amountStr) + if err != nil { + return nil, err + } + + // Add to the total amount + totalAmount.Add(totalAmount, partialAmount) + } + + return totalAmount, nil +} + +// convert decimal amount into big integer string +func parseDecimalAmount(str string, num_decimals int) string { + // Get the integer and decimal parts + idx := strings.Index(str, ".") + if idx == -1 { + return str + } + p1 := str[0:idx] + p2 := str[idx+1:] + + // Check for another decimal point + if strings.Index(p2, ".") != -1 { + return "error" + } + + // Compute the amount of zero digits to add + to_add := num_decimals - len(p2) + if to_add > 0 { + p2 = p2 + strings.Repeat("0", to_add) + } else if to_add < 0 { + // Do not truncate decimal amounts + return "error" + } + + // Join the integer and decimal parts + str = p1 + p2 + + // Remove leading zeros + str = strings.TrimLeft(str, "0") + if str == "" { + str = "0" + } + return str +} + +// parseAndConvert is a helper function to parse the substring as a big integer +// and apply the necessary multiplier based on the unit. +func parseAndConvert(subStr, unit string, mulUnit *big.Int, fullStr string) (*big.Int, error) { + subStr = strings.TrimSpace(subStr) + + // Convert the string to a big integer + amountBig, valid := new(big.Int).SetString(subStr, 10) + if !valid { + // Emits a backwards compatible error message + // the same as: dataType := len(unit) > 0 ? "BigNum" : "Integer" + dataType := map[bool]string{true: "BigNum", false: "Integer"}[len(unit) > 0] + return nil, errors.New("converting error for " + dataType + ": " + strings.TrimSpace(fullStr)) + } + + // Check for negative amounts + if amountBig.Cmp(zeroBig) < 0 { + return nil, errors.New("negative amount not allowed") + } + + // Apply multiplier based on unit + if mulUnit != zeroBig { + amountBig.Mul(amountBig, mulUnit) + } + + return amountBig, nil +} + +func (ctx *vmContext) handleDeploy(args []string) (result string, err error) { + if len(args) != 4 { + return "", errors.New("[Contract.Deploy] invalid number of arguments") + } + codeOrAddress, fargs, amount, gas := args[0], args[1], args[2], args[3] + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + if ctx.isQuery || ctx.nestedView > 0 { + errmsg[NEW_MSG] = "[Contract.Deploy] deploy not permitted in view" + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]send not permitted in query" + return "", errors.New(errmsg[errnum]) + } + bs := ctx.bs + + // contract code + var codeABI []byte + var sourceCode []byte + + // check if contract name or address is given + cid, err := getAddressNameResolved(codeOrAddress, bs) + if err == nil { + // check if contract exists + contractState, err := getOnlyContractState(ctx, cid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]" + return "", errors.New(errmsg[errnum] + err.Error()) + } + // read the contract code + codeABI, err = contractState.GetCode() + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]" + return "", errors.New(errmsg[errnum] + err.Error()) + } else if len(codeABI) == 0 { + errmsg[NEW_MSG] = "[Contract.Deploy] not found code" + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]: not found code" + return "", errors.New(errmsg[errnum]) + } + if ctx.blockInfo.ForkVersion >= 4 { + sourceCode = contractState.GetSourceCode() + } + } + + //! maybe not needed on hardfork 5, if using Lua for new contracts + // but it could at least check the code for validity + + // compile contract code if not found + if len(codeABI) == 0 { + codeABI, err = Compile(codeOrAddress, true) + if err != nil { + // check if string contains timeout error + if strings.Contains(err.Error(), C.ERR_BF_TIMEOUT) { + return "", err //errors.New(C.ERR_BF_TIMEOUT) + } else if err == ErrVmStart { + errmsg[NEW_MSG] = "[Contract.Deploy] get luaState error" + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]get luaState error" + return "", errors.New(errmsg[errnum]) + } + errmsg[NEW_MSG] = "[Contract.Deploy] compile error: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]compile error:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + if ctx.blockInfo.ForkVersion >= 4 { + sourceCode = []byte(codeOrAddress) + } + } + + err = ctx.addUpdateSize(int64(len(codeABI) + len(sourceCode))) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // create account for the contract + creator := ctx.curContract.callState.accState + newContract, err := state.CreateAccountState(CreateContractID(ctx.curContract.contractId, creator.Nonce()), bs.StateDB) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + contractState, err := statedb.OpenContractState(newContract.ID(), newContract.State(), bs.StateDB) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + cs := &callState{isCallback: true, isDeploy: true, ctrState: contractState, accState: newContract} + ctx.callState[newContract.AccountID()] = cs + + // read the amount transferred to the contract + amountBig, err := transformAmount(amount, ctx.blockInfo.ForkVersion) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] value not proper format: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]value not proper format:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // read the arguments for the constructor call + var ci types.CallInfo + err = getCallInfo(&ci.Args, []byte(fargs), newContract.ID()) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] invalid args: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]invalid args:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // send the amount to the contract + senderState := ctx.curContract.callState.accState + receiverState := cs.accState + if amountBig.Cmp(zeroBig) > 0 { + if rv := sendBalance(senderState, receiverState, amountBig); rv != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]" + return "", errors.New(errmsg[errnum] + rv.Error()) + } + } + + // create a recovery point + seq, err := setRecoveryPoint(newContract.AccountID(), ctx, senderState, cs, amountBig, false, true) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] DB err: " + errmsg[OLD_MSG] = "[System.LuaDeployContract] DB err:" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[DEPLOY] %s(%s)\n", + types.EncodeAddress(newContract.ID()), newContract.AccountID().String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("deploy snapshot set %d\n", seq)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("SendBalance : %s\n", amountBig.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", + senderState.Balance().String(), receiverState.Balance().String())) + } + + // set the contract info + prevContract := ctx.curContract + ctx.curContract = newContractInfo(cs, prevContract.contractId, newContract.ID(), receiverState.RP(), amountBig) + defer func() { + ctx.curContract = prevContract + }() + + bytecode := util.LuaCode(codeABI).ByteCode() + + // save the contract code + err = contractState.SetCode(sourceCode, codeABI) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // save the contract creator + err = contractState.SetData(dbkey.CreatorMeta(), []byte(types.EncodeAddress(prevContract.contractId))) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]" + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // get the remaining gas or gas limit from the parent contract + gasLimit, err := ctx.parseGasLimit(gas) + if err != nil { + return "", err + } + + // create a new executor + ce := newExecutor(bytecode, newContract.ID(), ctx, &ci, amountBig, true, false, contractState) + defer ce.close() // close the executor and the VM instance + if ce.err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] newExecutor Error: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract]newExecutor Error :" + return "", errors.New(errmsg[errnum] + ce.err.Error()) + } + + // set the remaining gas or gas limit from the parent contract + ce.contractGasLimit = gasLimit + + // increment the nonce of the creator + senderState.SetNonce(senderState.Nonce() + 1) + + addr := types.EncodeAddress(newContract.ID()) + + if ce != nil { + // run the constructor + ce.call(false) + + // the result contains the used gas in the first 8 bytes + result = ce.jsonRet + + // check if the execution was successful + if ce.err != nil { + // revert the contract to the previous state + err := clearRecovery(ctx, seq, true) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] recovery error: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract] recovery error: " + return result, errors.New(errmsg[errnum] + err.Error()) + } + // log some info + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) + } + // in case of timeout, return the original error message + switch ceErr := ce.err.(type) { + case *VmTimeoutError: + return result, errors.New(ceErr.Error()) + default: + errmsg[NEW_MSG] = "[Contract.Deploy] call err: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract] call err:" + return result, errors.New(errmsg[errnum] + ce.err.Error()) + } + } + } + + // release the recovery point + if seq == 1 { + err := clearRecovery(ctx, seq, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Deploy] recovery error: " + errmsg[OLD_MSG] = "[Contract.LuaDeployContract] recovery error: " + return result, errors.New(errmsg[errnum] + err.Error()) + } + } + + // the result already contains a JSON array + // insert the contract address before the other returned values + // the first 8 bytes contain the used gas + result = result[:8] + `["` + addr + `",` + result[9:] + + return result, nil +} + +func setRandomSeed(ctx *vmContext) { + var randSrc rand.Source + if ctx.isQuery { + randSrc = rand.NewSource(ctx.blockInfo.Ts) + } else { + b, _ := new(big.Int).SetString(base58.Encode(ctx.blockInfo.PrevBlockHash[:7]), 62) + t, _ := new(big.Int).SetString(base58.Encode(ctx.txHash[:7]), 62) + b.Add(b, t) + randSrc = rand.NewSource(b.Int64()) + } + ctx.seed = rand.New(randSrc) +} + +func (ctx *vmContext) handleRandomInt(args []string) (result string, err error) { + if len(args) != 2 { + return "", errors.New("[Contract.RandomInt] invalid number of arguments") + } + min, err := strconv.Atoi(args[0]) + if err != nil { + return "", errors.New("[Contract.RandomInt] invalid min") + } + max, err := strconv.Atoi(args[1]) + if err != nil { + return "", errors.New("[Contract.RandomInt] invalid max") + } + if ctx.seed == nil { + setRandomSeed(ctx) + } + return strconv.Itoa(ctx.seed.Intn(max+1-min) + min), nil +} + +func (ctx *vmContext) handleEvent(args []string) (result string, err error) { + if len(args) != 2 { + return "", errors.New("[Contract.Event] invalid number of arguments") + } + eventName, eventArgs := args[0], args[1] + if ctx.isQuery || ctx.nestedView > 0 { + return "", errors.New("[Contract.Event] event not permitted in query") + } + if ctx.eventCount >= maxEventCnt(ctx) { + return "", errors.New(fmt.Sprintf("[Contract.Event] exceeded the maximum number of events(%d)", maxEventCnt(ctx))) + } + if len(eventName) > maxEventNameSize { + return "", errors.New(fmt.Sprintf("[Contract.Event] exceeded the maximum length of event name(%d)", maxEventNameSize)) + } + if len(eventArgs) > maxEventArgSize { + return "", errors.New(fmt.Sprintf("[Contract.Event] exceeded the maximum length of event args(%d)", maxEventArgSize)) + } + ctx.events = append( + ctx.events, + &types.Event{ + ContractAddress: ctx.curContract.contractId, + EventIdx: ctx.eventCount, + EventName: eventName, + JsonArgs: eventArgs, + }, + ) + ctx.eventCount++ + return "", nil +} + +func (ctx *vmContext) handleToPubkey(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.ToPubkey] invalid number of arguments") + } + address := args[0] + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + // check the length of address + if len(address) != types.EncodedAddressLength { + errmsg[NEW_MSG] = "[Contract.ToPubkey] invalid address length" + errmsg[OLD_MSG] = "[Contract.LuaToPubkey] invalid address length" + return "", errors.New(errmsg[errnum]) + } + // decode the address in string format to bytes (public key) + pubkey, err := types.DecodeAddress(address) + if err != nil { + errmsg[NEW_MSG] = "[Contract.ToPubkey] invalid address" + errmsg[OLD_MSG] = "[Contract.LuaToPubkey] invalid address" + return "", errors.New(errmsg[errnum]) + } + // return the public key in hex format + return "0x" + hex.Encode(pubkey), nil +} + +func (ctx *vmContext) handleToAddress(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.ToAddress] invalid number of arguments") + } + pubkey := args[0] + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + // decode the pubkey in hex format to bytes + pubkeyBytes, err := decodeHex(pubkey) + if err != nil { + errmsg[NEW_MSG] = "[Contract.ToAddress] invalid public key" + errmsg[OLD_MSG] = "[Contract.LuaToAddress] invalid public key" + return "", errors.New(errmsg[errnum]) + } + // check the length of pubkey + if len(pubkeyBytes) != types.AddressLength { + errmsg[NEW_MSG] = "[Contract.ToAddress] invalid public key length" + errmsg[OLD_MSG] = "[Contract.LuaToAddress] invalid public key length" + return "", errors.New(errmsg[errnum]) + // or convert the pubkey to compact format - SerializeCompressed() + } + // encode the pubkey in bytes to an address in string format + address := types.EncodeAddress(pubkeyBytes) + // return the address + return address, nil +} + +func (ctx *vmContext) handleIsContract(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.IsContract] invalid number of arguments") + } + contractAddress := args[0] + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + cid, err := getAddressNameResolved(contractAddress, ctx.bs) + if err != nil { + errmsg[NEW_MSG] = "[Contract.IsContract] invalid contractId: " + errmsg[OLD_MSG] = "[Contract.LuaIsContract] invalid contractId: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + cs, err := getCallState(ctx, cid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.IsContract] getAccount error: " + errmsg[OLD_MSG] = "[Contract.LuaIsContract] getAccount error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + return strconv.Itoa(len(cs.accState.CodeHash())), nil +} + +func (ctx *vmContext) handleNameResolve(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.NameResolve] invalid number of arguments") + } + account := args[0] // account name or address + var addr []byte + if len(account) == types.EncodedAddressLength { + // also checks if valid address + addr, err = types.DecodeAddress(account) + } else { + addr, err = name.Resolve(ctx.bs, []byte(account), false) + } + if err != nil { + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + errmsg[NEW_MSG] = "[Contract.NameResolve] " + errmsg[OLD_MSG] = "[Contract.LuaNameResolve] " + return "", errors.New(errmsg[errnum] + err.Error()) + } + return types.EncodeAddress(addr), nil +} + +func (ctx *vmContext) handleGovernance(args []string) (result string, err error) { + if len(args) != 2 { + return "", errors.New("[Contract.Governance] invalid number of arguments") + } + gType, arg := args[0], args[1] + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + if ctx.isQuery || ctx.nestedView > 0 { + errmsg[NEW_MSG] = "[Contract.Governance] governance not permitted in query" + errmsg[OLD_MSG] = "[Contract.LuaGovernance] governance not permitted in query" + return "", errors.New(errmsg[errnum]) + } + + var amountBig *big.Int + var payload []byte + + switch gType { + case "S", "U": + var err error + amountBig, err = transformAmount(arg, ctx.blockInfo.ForkVersion) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Governance] invalid amount: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] invalid amount: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + if gType == "S" { + payload = []byte(fmt.Sprintf(`{"Name":"%s"}`, types.Opstake.Cmd())) + } else { + payload = []byte(fmt.Sprintf(`{"Name":"%s"}`, types.Opunstake.Cmd())) + } + case "V": + amountBig = zeroBig + payload = []byte(fmt.Sprintf(`{"Name":"%s","Args":%s}`, types.OpvoteBP.Cmd(), arg)) + case "D": + amountBig = zeroBig + payload = []byte(fmt.Sprintf(`{"Name":"%s","Args":%s}`, types.OpvoteDAO.Cmd(), arg)) + } + + cid := []byte(types.AergoSystem) + aid := types.ToAccountID(cid) + scsState, err := getContractState(ctx, cid) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Governance] getAccount error: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] getAccount error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + curContract := ctx.curContract + + senderState := curContract.callState.accState + receiverState := scsState.accState + + txBody := types.TxBody{ + Amount: amountBig.Bytes(), + Payload: payload, + } + if ctx.blockInfo.ForkVersion >= 2 { + txBody.Account = curContract.contractId + } + + err = types.ValidateSystemTx(&txBody) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Governance] error: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // create a recovery point + seq, err := setRecoveryPoint(aid, ctx, senderState, scsState, zeroBig, false, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Governance] database error: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] database error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // execute the system transaction + events, err := system.ExecuteSystemTx(scsState.ctrState, &txBody, senderState, receiverState, ctx.blockInfo) + if err != nil { + // revert the contract to the previous state + rErr := clearRecovery(ctx, seq, true) + if rErr != nil { + errmsg[NEW_MSG] = "[Contract.Governance] recovery error: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] recovery error: " + return "", errors.New(errmsg[errnum] + rErr.Error()) + } + errmsg[NEW_MSG] = "[Contract.Governance] error: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + + // release the recovery point + if seq == 1 { + err := clearRecovery(ctx, seq, false) + if err != nil { + errmsg[NEW_MSG] = "[Contract.Governance] recovery error: " + errmsg[OLD_MSG] = "[Contract.LuaGovernance] recovery error: " + return "", errors.New(errmsg[errnum] + err.Error()) + } + } + + // add the events to the context + ctx.eventCount += int32(len(events)) + ctx.events = append(ctx.events, events...) + + if ctx.lastRecoveryEntry != nil { + if gType == "S" { + seq, _ = setRecoveryPoint(aid, ctx, senderState, scsState, amountBig, true, false) + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[GOVERNANCE]aid(%s)\n", aid.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("staking : %s\n", amountBig.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", + senderState.Balance().String(), receiverState.Balance().String())) + } + } else if gType == "U" { + seq, _ = setRecoveryPoint(aid, ctx, receiverState, ctx.curContract.callState, amountBig, true, false) + if ctx.traceFile != nil { + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[GOVERNANCE]aid(%s)\n", aid.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("unstaking : %s\n", amountBig.String())) + _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", + senderState.Balance().String(), receiverState.Balance().String())) + } + } + } + + return "", nil +} + + +//////////////////////////////////////////////////////////////////////////////// + +func (ctx *vmContext) handleDbExec(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.Exec] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_db_exec(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +func (ctx *vmContext) handleDbQuery(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.Query] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_db_query(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +func (ctx *vmContext) handleDbPrepare(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.Prepare] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_db_prepare(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +//stmtExec +func (ctx *vmContext) handleStmtExec(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.StmtExec] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_stmt_exec(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +//stmtQuery +func (ctx *vmContext) handleStmtQuery(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.StmtQuery] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_stmt_query(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +//stmtColumnInfo +func (ctx *vmContext) handleStmtColumnInfo(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.StmtColumnInfo] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_stmt_column_info(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +//rsNext +func (ctx *vmContext) handleRsNext(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.RsNext] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_rs_next(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +//rsGet +func (ctx *vmContext) handleRsGet(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.RsGet] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_rs_get(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +/* +//rsColumnInfo +func (ctx *vmContext) handleRsColumnInfo(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.RsColumnInfo] invalid number of arguments") + } + col_id, err := strconv.Atoi(args[0]) + if err != nil { + return "", errors.New("[DB.RsColumnInfo] invalid column id") + } + + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_rs_column_info(&cReq, C.int(col_id)) + return processResult(&cReq) +} +*/ + +/* +//rsClose +func (ctx *vmContext) handleRsClose(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.RsClose] invalid number of arguments") + } + query_id, err := strconv.Atoi(args[0]) + if err != nil { + return "", errors.New("[DB.RsClose] invalid query id") + } + + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_rs_close(&cReq, C.int(query_id)) + return processResult(&cReq) +} +*/ + +//lastInsertRowid +func (ctx *vmContext) handleLastInsertRowid(args []string) (result string, err error) { + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_last_insert_rowid(&cReq) + return processResult(&cReq) +} + +//dbOpenWithSnapshot +func (ctx *vmContext) handleDbOpenWithSnapshot(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[DB.DbOpenWithSnapshot] invalid number of arguments") + } + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_db_open_with_snapshot(&cReq, (*C.char)(unsafe.Pointer(&[]byte(args[0])[0])), C.int(len(args[0]))) + return processResult(&cReq) +} + +//dbGetSnapshot +func (ctx *vmContext) handleDbGetSnapshot(args []string) (result string, err error) { + var cReq C.request + cReq.service = C.int(ctx.service) + C.handle_db_get_snapshot(&cReq) + return processResult(&cReq) +} + +func processResult(cReq *C.request) (result string, err error) { + if cReq.result.ptr != nil { + result = C.GoStringN(cReq.result.ptr, cReq.result.len) + C.free(unsafe.Pointer(cReq.result.ptr)) + } + if cReq.error != nil { + errstr := C.GoString(cReq.error) + C.free(unsafe.Pointer(cReq.error)) + err = errors.New(errstr) + } + return result, err +} + + + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + + +//export isPublic +func isPublic() C.int { + if PubNet { + return C.int(1) + } else { + return C.int(0) + } +} + + + + + +// this is only used at server side, by db_module.c + +//export luaIsView +func luaIsView(service C.int) C.bool { + ctx := contexts[service] + return C.bool(ctx.nestedView > 0) +} + + + + +// checks whether the block creation timeout occurred +// +func checkTimeout(service int) bool { + + // only check timeout for the block factory + if service != BlockFactory { + return false + } + + ctx := contexts[service] + select { + case <-ctx.execCtx.Done(): + return true + default: + return false + } + +} + +//export LuaGetDbHandleSnap +func LuaGetDbHandleSnap(service C.int, snapshot *C.char) *C.char { + ctx := contexts[service] + + curContract := ctx.curContract + callState := curContract.callState + + errmsg := [2]string{} + errnum := iif(CurrentForkVersion >= 5, NEW_MSG, OLD_MSG) + + if ctx.isQuery != true { + errmsg[NEW_MSG] = "[Contract.SetDbSnap] not permitted in transaction" + errmsg[OLD_MSG] = "[Contract.LuaSetDbSnap] not permitted in transaction" + return C.CString(errmsg[errnum]) + } + + if callState.tx != nil { + errmsg[NEW_MSG] = "[Contract.SetDbSnap] transaction already started" + errmsg[OLD_MSG] = "[Contract.LuaSetDbSnap] transaction already started" + return C.CString(errmsg[errnum]) + } + + rp, err := strconv.ParseUint(C.GoString(snapshot), 10, 64) + if err != nil { + errmsg[NEW_MSG] = "[Contract.SetDbSnap] snapshot is not valid: " + errmsg[OLD_MSG] = "[Contract.LuaSetDbSnap] snapshot is not valid" + return C.CString(errmsg[errnum] + C.GoString(snapshot)) + } + + aid := types.ToAccountID(curContract.contractId) + tx, err := beginReadOnly(aid.String(), rp) + if err != nil { + errmsg[NEW_MSG] = "[Contract.SetDbSnap] Error Begin SQL Transaction" + errmsg[OLD_MSG] = "Error Begin SQL Transaction" + return C.CString(errmsg[errnum]) + } + + callState.tx = tx + return nil +} + +//export LuaGetDbSnapshot +func LuaGetDbSnapshot(service C.int) *C.char { + ctx := contexts[service] + return C.CString(strconv.FormatUint(ctx.curContract.rp, 10)) +} + + + + +func (ctx *vmContext) handleGetStaking(args []string) (result string, err error) { + if len(args) != 1 { + return "", errors.New("[Contract.GetStaking] invalid number of arguments") + } + addr := args[0] + + systemcs, err := statedb.GetSystemAccountState(ctx.bs.StateDB) + if err != nil { + return "", err + } + + namecs, err := statedb.GetNameAccountState(ctx.bs.StateDB) + if err != nil { + return "", err + } + + staking, err := system.GetStaking(systemcs, name.GetAddress(namecs, types.ToAddress(addr))) + if err != nil { + return "", err + } + + // returns a string with the amount and when + result = staking.GetAmountBigInt().String() + "," + strconv.FormatUint(staking.When, 10) + return result, nil +} + + +//////////////////////////////////////////////////////////////////////////////// + +func sendBalance(sender *state.AccountState, receiver *state.AccountState, amount *big.Int) error { + if err := state.SendBalance(sender, receiver, amount); err != nil { + if CurrentForkVersion >= 5 { + return errors.New("insufficient balance: " + sender.Balance().String() + + " - amount to transfer: " + amount.String()) + } else { + return errors.New("[Contract.sendBalance] insufficient balance: " + + sender.Balance().String() + " : " + amount.String()) + } + } + return nil +} diff --git a/contract/vm_callback_test.go b/contract/vm_api_test.go similarity index 100% rename from contract/vm_callback_test.go rename to contract/vm_api_test.go diff --git a/contract/vm_callback.go b/contract/vm_callback.go deleted file mode 100644 index 0cadc5b32..000000000 --- a/contract/vm_callback.go +++ /dev/null @@ -1,1701 +0,0 @@ -package contract - -/* -#cgo CFLAGS: -I${SRCDIR}/../libtool/include/luajit-2.1 -#cgo LDFLAGS: ${SRCDIR}/../libtool/lib/libluajit-5.1.a -lm - -#include -#include -#include "vm.h" -#include "bignum_module.h" - -struct proof { - void *data; - size_t len; -}; - -#define RLP_TSTRING 0 -#define RLP_TLIST 1 - -struct rlp_obj { - int rlp_obj_type; - void *data; - size_t size; -}; -*/ -import "C" -import ( - "bytes" - "crypto/sha256" - "errors" - "fmt" - "math/big" - "strconv" - "strings" - "unsafe" - - "github.com/aergoio/aergo-lib/log" - "github.com/aergoio/aergo/v2/cmd/aergoluac/util" - "github.com/aergoio/aergo/v2/contract/name" - "github.com/aergoio/aergo/v2/contract/system" - "github.com/aergoio/aergo/v2/internal/common" - "github.com/aergoio/aergo/v2/internal/enc/base58" - "github.com/aergoio/aergo/v2/internal/enc/hex" - "github.com/aergoio/aergo/v2/state" - "github.com/aergoio/aergo/v2/state/statedb" - "github.com/aergoio/aergo/v2/types" - "github.com/aergoio/aergo/v2/types/dbkey" - "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcec/v2/ecdsa" -) - -var ( - mulAergo, mulGaer, zeroBig *big.Int - vmLogger = log.NewLogger("contract.vm") -) - -const ( - maxEventCntV2 = 50 - maxEventCntV4 = 128 - maxEventNameSize = 64 - maxEventArgSize = 4096 - luaCallCountDeduc = 1000 -) - -func init() { - mulAergo = types.NewAmount(1, types.Aergo) - mulGaer = types.NewAmount(1, types.Gaer) - zeroBig = types.NewZeroAmount() -} - -func maxEventCnt(ctx *vmContext) int32 { - if ctx.blockInfo.ForkVersion >= 4 { - return maxEventCntV4 - } else { - return maxEventCntV2 - } -} - -//export luaSetDB -func luaSetDB(L *LState, service C.int, key unsafe.Pointer, keyLen C.int, value *C.char) *C.char { - ctx := contexts[service] - if ctx == nil { - return C.CString("[System.LuaSetDB] contract state not found") - } - if ctx.isQuery == true || ctx.nestedView > 0 { - return C.CString("[System.LuaSetDB] set not permitted in query") - } - val := []byte(C.GoString(value)) - if err := ctx.curContract.callState.ctrState.SetData(C.GoBytes(key, keyLen), val); err != nil { - return C.CString(err.Error()) - } - if err := ctx.addUpdateSize(int64(types.HashIDLength + len(val))); err != nil { - C.luaL_setuncatchablerror(L) - return C.CString(err.Error()) - } - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString("[Set]\n") - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("Key=%s Len=%v byte=%v\n", - string(C.GoBytes(key, keyLen)), keyLen, C.GoBytes(key, keyLen))) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("Data=%s Len=%d byte=%v\n", - string(val), len(val), val)) - } - return nil -} - -//export luaGetDB -func luaGetDB(L *LState, service C.int, key unsafe.Pointer, keyLen C.int, blkno *C.char) (*C.char, *C.char) { - ctx := contexts[service] - if ctx == nil { - return nil, C.CString("[System.LuaGetDB] contract state not found") - } - if blkno != nil { - bigNo, _ := new(big.Int).SetString(strings.TrimSpace(C.GoString(blkno)), 10) - if bigNo == nil || bigNo.Sign() < 0 { - return nil, C.CString("[System.LuaGetDB] invalid blockheight value :" + C.GoString(blkno)) - } - blkNo := bigNo.Uint64() - - chainBlockHeight := ctx.blockInfo.No - if chainBlockHeight == 0 { - bestBlock, err := ctx.cdb.GetBestBlock() - if err != nil { - return nil, C.CString("[System.LuaGetDB] get best block error") - } - chainBlockHeight = bestBlock.GetHeader().GetBlockNo() - } - if blkNo < chainBlockHeight { - blk, err := ctx.cdb.GetBlockByNo(blkNo) - if err != nil { - return nil, C.CString(err.Error()) - } - accountId := types.ToAccountID(ctx.curContract.contractId) - contractProof, err := ctx.bs.GetAccountAndProof(accountId[:], blk.GetHeader().GetBlocksRootHash(), false) - if err != nil { - return nil, C.CString("[System.LuaGetDB] failed to get snapshot state for account") - } else if contractProof.Inclusion { - trieKey := common.Hasher(C.GoBytes(key, keyLen)) - varProof, err := ctx.bs.GetVarAndProof(trieKey, contractProof.GetState().GetStorageRoot(), false) - if err != nil { - return nil, C.CString("[System.LuaGetDB] failed to get snapshot state variable in contract") - } - if varProof.Inclusion { - if len(varProof.GetValue()) == 0 { - return nil, nil - } - return C.CString(string(varProof.GetValue())), nil - } - } - return nil, nil - } - } - - data, err := ctx.curContract.callState.ctrState.GetData(C.GoBytes(key, keyLen)) - if err != nil { - return nil, C.CString(err.Error()) - } - if data == nil { - return nil, nil - } - return C.CString(string(data)), nil -} - -//export luaDelDB -func luaDelDB(L *LState, service C.int, key unsafe.Pointer, keyLen C.int) *C.char { - ctx := contexts[service] - if ctx == nil { - return C.CString("[System.LuaDelDB] contract state not found") - } - if ctx.isQuery == true || ctx.nestedView > 0 { - return C.CString("[System.LuaDelDB] delete not permitted in query") - } - if err := ctx.curContract.callState.ctrState.DeleteData(C.GoBytes(key, keyLen)); err != nil { - return C.CString(err.Error()) - } - if err := ctx.addUpdateSize(int64(32)); err != nil { - C.luaL_setuncatchablerror(L) - return C.CString(err.Error()) - } - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString("[Del]\n") - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("Key=%s Len=%v byte=%v\n", - string(C.GoBytes(key, keyLen)), keyLen, C.GoBytes(key, keyLen))) - } - return nil -} - -func setInstCount(ctx *vmContext, parent *LState, child *LState) { - if !ctx.IsGasSystem() { - C.vm_setinstcount(parent, C.vm_instcount(child)) - } -} - -func setInstMinusCount(ctx *vmContext, L *LState, deduc C.int) { - if !ctx.IsGasSystem() { - C.vm_setinstcount(L, minusCallCount(ctx, C.vm_instcount(L), deduc)) - } -} - -func minusCallCount(ctx *vmContext, curCount, deduc C.int) C.int { - if ctx.IsGasSystem() { - return 0 - } - remain := curCount - deduc - if remain <= 0 { - remain = 1 - } - return remain -} - -//export luaCallContract -func luaCallContract(L *LState, service C.int, contractId *C.char, fname *C.char, args *C.char, - amount *C.char, gas uint64) (C.int, *C.char) { - fnameStr := C.GoString(fname) - argsStr := C.GoString(args) - - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.LuaCallContract] contract state not found") - } - - // get the contract address - contractAddress := C.GoString(contractId) - cid, err := getAddressNameResolved(contractAddress, ctx.bs) - if err != nil { - return -1, C.CString("[Contract.LuaCallContract] invalid contractId: " + err.Error()) - } - aid := types.ToAccountID(cid) - - // read the amount for the contract call - amountBig, err := transformAmount(C.GoString(amount), ctx.blockInfo.ForkVersion) - if err != nil { - return -1, C.CString("[Contract.LuaCallContract] invalid amount: " + err.Error()) - } - - // get the contract state - cs, err := getContractState(ctx, cid) - if err != nil { - return -1, C.CString("[Contract.LuaCallContract] getAccount error: " + err.Error()) - } - - // check if the contract exists - bytecode := getContractCode(cs.ctrState, ctx.bs) - if bytecode == nil { - return -1, C.CString("[Contract.LuaCallContract] cannot find contract " + C.GoString(contractId)) - } - - prevContractInfo := ctx.curContract - - // read the arguments for the contract call - var ci types.CallInfo - ci.Name = fnameStr - err = getCallInfo(&ci.Args, []byte(argsStr), cid) - if err != nil { - return -1, C.CString("[Contract.LuaCallContract] invalid arguments: " + err.Error()) - } - - // get the remaining gas from the parent LState - ctx.refreshRemainingGas(L) - // create a new executor with the remaining gas on the child LState - ce := newExecutor(bytecode, cid, ctx, &ci, amountBig, false, false, cs.ctrState) - defer func() { - // close the executor, closes also the child LState - ce.close() - // set the remaining gas on the parent LState - ctx.setRemainingGas(L) - }() - - if ce.err != nil { - return -1, C.CString("[Contract.LuaCallContract] newExecutor error: " + ce.err.Error()) - } - - // send the amount to the contract - senderState := prevContractInfo.callState.accState - receiverState := cs.accState - if amountBig.Cmp(zeroBig) > 0 { - if ctx.isQuery == true || ctx.nestedView > 0 { - return -1, C.CString("[Contract.LuaCallContract] send not permitted in query") - } - if r := sendBalance(senderState, receiverState, amountBig); r != nil { - return -1, r - } - } - - seq, err := setRecoveryPoint(aid, ctx, senderState, cs, amountBig, false, false) - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[CALL Contract %v(%v) %v]\n", - contractAddress, aid.String(), fnameStr)) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("SendBalance: %s\n", amountBig.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", - senderState.Balance().String(), receiverState.Balance().String())) - } - if err != nil { - return -1, C.CString("[System.LuaCallContract] database error: " + err.Error()) - } - - // set the current contract info - ctx.curContract = newContractInfo(cs, prevContractInfo.contractId, cid, - receiverState.RP(), amountBig) - defer func() { - ctx.curContract = prevContractInfo - }() - - // execute the contract call - defer setInstCount(ctx, L, ce.L) - ret := ce.call(minusCallCount(ctx, C.vm_instcount(L), luaCallCountDeduc), L) - - // check if the contract call failed - if ce.err != nil { - err := clearRecovery(L, ctx, seq, true) - if err != nil { - return -1, C.CString("[Contract.LuaCallContract] recovery err: " + err.Error()) - } - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) - } - switch ceErr := ce.err.(type) { - case *VmTimeoutError: - return -1, C.CString(ceErr.Error()) - default: - return -1, C.CString("[Contract.LuaCallContract] call err: " + ceErr.Error()) - - } - } - - if seq == 1 { - err := clearRecovery(L, ctx, seq, false) - if err != nil { - return -1, C.CString("[Contract.LuaCallContract] recovery err: " + err.Error()) - } - } - - return ret, nil -} - -//export luaDelegateCallContract -func luaDelegateCallContract(L *LState, service C.int, contractId *C.char, - fname *C.char, args *C.char, gas uint64) (C.int, *C.char) { - contractIdStr := C.GoString(contractId) - fnameStr := C.GoString(fname) - argsStr := C.GoString(args) - - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] contract state not found") - } - - var isMultiCall bool - var cid []byte - var err error - - // get the contract address - if contractIdStr == "multicall" { - isMultiCall = true - argsStr = fnameStr - fnameStr = "execute" - cid = ctx.curContract.contractId - } else { - cid, err = getAddressNameResolved(contractIdStr, ctx.bs) - if err != nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] invalid contractId: " + err.Error()) - } - } - aid := types.ToAccountID(cid) - - // get the contract state - var contractState *statedb.ContractState - if isMultiCall { - contractState = statedb.GetMultiCallState(cid, ctx.curContract.callState.ctrState.State) - } else { - contractState, err = getOnlyContractState(ctx, cid) - } - if err != nil { - return -1, C.CString("[Contract.LuaDelegateCallContract]getContractState error" + err.Error()) - } - - // get the contract code - var bytecode []byte - if isMultiCall { - bytecode = getMultiCallContractCode(contractState) - } else { - bytecode = getContractCode(contractState, ctx.bs) - } - if bytecode == nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] cannot find contract " + contractIdStr) - } - - // read the arguments for the contract call - var ci types.CallInfo - if isMultiCall { - err = getMultiCallInfo(&ci, []byte(argsStr)) - } else { - ci.Name = fnameStr - err = getCallInfo(&ci.Args, []byte(argsStr), cid) - } - if err != nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] invalid arguments: " + err.Error()) - } - - // get the remaining gas from the parent LState - ctx.refreshRemainingGas(L) - // create a new executor with the remaining gas on the child LState - ce := newExecutor(bytecode, cid, ctx, &ci, zeroBig, false, false, contractState) - defer func() { - // close the executor, closes also the child LState - ce.close() - // set the remaining gas on the parent LState - ctx.setRemainingGas(L) - }() - - if ce.err != nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] newExecutor error: " + ce.err.Error()) - } - - seq, err := setRecoveryPoint(aid, ctx, nil, ctx.curContract.callState, zeroBig, false, false) - if err != nil { - return -1, C.CString("[System.LuaDelegateCallContract] database error: " + err.Error()) - } - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[DELEGATECALL Contract %v %v]\n", contractIdStr, fnameStr)) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) - } - - // execute the contract call - defer setInstCount(ctx, L, ce.L) - ret := ce.call(minusCallCount(ctx, C.vm_instcount(L), luaCallCountDeduc), L) - - // check if the contract call failed - if ce.err != nil { - err := clearRecovery(L, ctx, seq, true) - if err != nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] recovery error: " + err.Error()) - } - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) - } - switch ceErr := ce.err.(type) { - case *VmTimeoutError: - return -1, C.CString(ceErr.Error()) - default: - return -1, C.CString("[Contract.LuaDelegateCallContract] call error: " + ce.err.Error()) - } - } - - if seq == 1 { - err := clearRecovery(L, ctx, seq, false) - if err != nil { - return -1, C.CString("[Contract.LuaDelegateCallContract] recovery error: " + err.Error()) - } - } - - return ret, nil -} - -func getAddressNameResolved(account string, bs *state.BlockState) ([]byte, error) { - accountLen := len(account) - if accountLen == types.EncodedAddressLength { - return types.DecodeAddress(account) - } else if accountLen == types.NameLength { - cid, err := name.Resolve(bs, []byte(account), false) - if err != nil { - return nil, err - } - if cid == nil { - return nil, errors.New("name not founded :" + account) - } - return cid, nil - } - return nil, errors.New("invalid account length:" + account) -} - -//export luaSendAmount -func luaSendAmount(L *LState, service C.int, contractId *C.char, amount *C.char) *C.char { - - ctx := contexts[service] - if ctx == nil { - return C.CString("[Contract.LuaSendAmount] contract state not found") - } - - // read the amount to be sent - amountBig, err := transformAmount(C.GoString(amount), ctx.blockInfo.ForkVersion) - if err != nil { - return C.CString("[Contract.LuaSendAmount] invalid amount: " + err.Error()) - } - - // cannot send amount in query - if (ctx.isQuery == true || ctx.nestedView > 0) && amountBig.Cmp(zeroBig) > 0 { - return C.CString("[Contract.LuaSendAmount] send not permitted in query") - } - - // get the receiver account - cid, err := getAddressNameResolved(C.GoString(contractId), ctx.bs) - if err != nil { - return C.CString("[Contract.LuaSendAmount] invalid contractId: " + err.Error()) - } - - // get the receiver state - aid := types.ToAccountID(cid) - cs, err := getCallState(ctx, cid) - if err != nil { - return C.CString("[Contract.LuaSendAmount] getAccount error: " + err.Error()) - } - - // get the sender state - senderState := ctx.curContract.callState.accState - receiverState := cs.accState - - // check if the receiver is a contract - if len(receiverState.CodeHash()) > 0 { - - // get the contract state - if cs.ctrState == nil { - cs.ctrState, err = statedb.OpenContractState(cid, receiverState.State(), ctx.bs.StateDB) - if err != nil { - return C.CString("[Contract.LuaSendAmount] getContractState error: " + err.Error()) - } - } - - // set the function to be called - var ci types.CallInfo - ci.Name = "default" - - // get the contract code - bytecode := getContractCode(cs.ctrState, ctx.bs) - if bytecode == nil { - return C.CString("[Contract.LuaSendAmount] cannot find contract:" + C.GoString(contractId)) - } - - // get the remaining gas from the parent LState - ctx.refreshRemainingGas(L) - // create a new executor with the remaining gas on the child LState - ce := newExecutor(bytecode, cid, ctx, &ci, amountBig, false, false, cs.ctrState) - defer func() { - // close the executor, closes also the child LState - ce.close() - // set the remaining gas on the parent LState - ctx.setRemainingGas(L) - }() - - if ce.err != nil { - return C.CString("[Contract.LuaSendAmount] newExecutor error: " + ce.err.Error()) - } - - // send the amount to the contract - if amountBig.Cmp(zeroBig) > 0 { - if r := sendBalance(senderState, receiverState, amountBig); r != nil { - return r - } - } - - // create a recovery point - seq, err := setRecoveryPoint(aid, ctx, senderState, cs, amountBig, false, false) - if err != nil { - return C.CString("[System.LuaSendAmount] database error: " + err.Error()) - } - - // log some info - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString( - fmt.Sprintf("[Send Call default] %s(%s) : %s\n", types.EncodeAddress(cid), aid.String(), amountBig.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", - senderState.Balance().String(), receiverState.Balance().String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) - } - - // set the current contract info - prevContractInfo := ctx.curContract - ctx.curContract = newContractInfo(cs, prevContractInfo.contractId, cid, - receiverState.RP(), amountBig) - defer func() { - ctx.curContract = prevContractInfo - }() - - // execute the contract call - defer setInstCount(ctx, L, ce.L) - ce.call(minusCallCount(ctx, C.vm_instcount(L), luaCallCountDeduc), L) - - // check if the contract call failed - if ce.err != nil { - // recover to the previous state - err := clearRecovery(L, ctx, seq, true) - if err != nil { - return C.CString("[Contract.LuaSendAmount] recovery err: " + err.Error()) - } - // log some info - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) - } - // return the error message - return C.CString("[Contract.LuaSendAmount] call err: " + ce.err.Error()) - } - - if seq == 1 { - err := clearRecovery(L, ctx, seq, false) - if err != nil { - return C.CString("[Contract.LuaSendAmount] recovery err: " + err.Error()) - } - } - - // the transfer and contract call succeeded - return nil - } - - // the receiver is not a contract, just send the amount - - // if amount is zero, do nothing - if amountBig.Cmp(zeroBig) == 0 { - return nil - } - - // send the amount to the receiver - if r := sendBalance(senderState, receiverState, amountBig); r != nil { - return r - } - - // update the recovery point - if ctx.lastRecoveryEntry != nil { - _, _ = setRecoveryPoint(aid, ctx, senderState, cs, amountBig, true, false) - } - - // log some info - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[Send] %s(%s) : %s\n", - types.EncodeAddress(cid), aid.String(), amountBig.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", - senderState.Balance().String(), receiverState.Balance().String())) - } - - return nil -} - -//export luaPrint -func luaPrint(L *LState, service C.int, args *C.char) { - ctx := contexts[service] - setInstMinusCount(ctx, L, 1000) - ctrLgr.Info().Str("Contract SystemPrint", types.EncodeAddress(ctx.curContract.contractId)).Msg(C.GoString(args)) -} - -//export luaSetRecoveryPoint -func luaSetRecoveryPoint(L *LState, service C.int) (C.int, *C.char) { - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.pcall] contract state not found") - } - if ctx.isQuery == true || ctx.nestedView > 0 { - return 0, nil - } - curContract := ctx.curContract - // if it is the multicall code, ignore - if curContract.callState.ctrState.IsMultiCall() { - return 0, nil - } - seq, err := setRecoveryPoint(types.ToAccountID(curContract.contractId), ctx, nil, - curContract.callState, zeroBig, false, false) - if err != nil { - return -1, C.CString("[Contract.pcall] database error: " + err.Error()) - } - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[Pcall] snapshot set %d\n", seq)) - } - return C.int(seq), nil -} - -func clearRecovery(L *LState, ctx *vmContext, start int, isError bool) error { - item := ctx.lastRecoveryEntry - for { - if isError { - if item.recovery(ctx.bs) != nil { - return errors.New("database error") - } - } - if item.seq == start { - if isError || item.prev == nil { - ctx.lastRecoveryEntry = item.prev - } - return nil - } - item = item.prev - if item == nil { - return errors.New("internal error") - } - } -} - -//export luaClearRecovery -func luaClearRecovery(L *LState, service C.int, start int, isError bool) *C.char { - ctx := contexts[service] - if ctx == nil { - return C.CString("[Contract.pcall] contract state not found") - } - err := clearRecovery(L, ctx, start, isError) - if err != nil { - return C.CString(err.Error()) - } - if ctx.traceFile != nil && isError == true { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("pcall recovery snapshot : %d\n", start)) - } - return nil -} - -//export luaGetBalance -func luaGetBalance(L *LState, service C.int, contractId *C.char) (*C.char, *C.char) { - ctx := contexts[service] - if contractId == nil { - return C.CString(ctx.curContract.callState.ctrState.GetBalanceBigInt().String()), nil - } - cid, err := getAddressNameResolved(C.GoString(contractId), ctx.bs) - if err != nil { - return nil, C.CString("[Contract.LuaGetBalance] invalid contractId: " + err.Error()) - } - aid := types.ToAccountID(cid) - cs := ctx.callState[aid] - if cs == nil { - bs := ctx.bs - as, err := bs.GetAccountState(aid) - if err != nil { - return nil, C.CString("[Contract.LuaGetBalance] getAccount error: " + err.Error()) - } - return C.CString(as.GetBalanceBigInt().String()), nil - } - return C.CString(cs.accState.Balance().String()), nil -} - -//export luaGetSender -func luaGetSender(L *LState, service C.int) *C.char { - ctx := contexts[service] - setInstMinusCount(ctx, L, 1000) - return C.CString(types.EncodeAddress(ctx.curContract.sender)) -} - -//export luaGetHash -func luaGetHash(L *LState, service C.int) *C.char { - ctx := contexts[service] - return C.CString(base58.Encode(ctx.txHash)) -} - -//export luaGetBlockNo -func luaGetBlockNo(L *LState, service C.int) C.lua_Integer { - ctx := contexts[service] - return C.lua_Integer(ctx.blockInfo.No) -} - -//export luaGetTimeStamp -func luaGetTimeStamp(L *LState, service C.int) C.lua_Integer { - ctx := contexts[service] - return C.lua_Integer(ctx.blockInfo.Ts / 1e9) -} - -//export luaGetContractId -func luaGetContractId(L *LState, service C.int) *C.char { - ctx := contexts[service] - setInstMinusCount(ctx, L, 1000) - return C.CString(types.EncodeAddress(ctx.curContract.contractId)) -} - -//export luaGetAmount -func luaGetAmount(L *LState, service C.int) *C.char { - ctx := contexts[service] - return C.CString(ctx.curContract.amount.String()) -} - -//export luaGetOrigin -func luaGetOrigin(L *LState, service C.int) *C.char { - ctx := contexts[service] - setInstMinusCount(ctx, L, 1000) - return C.CString(types.EncodeAddress(ctx.origin)) -} - -//export luaGetPrevBlockHash -func luaGetPrevBlockHash(L *LState, service C.int) *C.char { - ctx := contexts[service] - return C.CString(base58.Encode(ctx.blockInfo.PrevBlockHash)) -} - -//export luaGetDbHandle -func luaGetDbHandle(service C.int) *C.sqlite3 { - ctx := contexts[service] - curContract := ctx.curContract - cs := curContract.callState - if cs.tx != nil { - return cs.tx.getHandle() - } - var tx sqlTx - var err error - - aid := types.ToAccountID(curContract.contractId) - if ctx.isQuery == true { - tx, err = beginReadOnly(aid.String(), curContract.rp) - } else { - tx, err = beginTx(aid.String(), curContract.rp) - } - if err != nil { - sqlLgr.Error().Err(err).Msg("Begin SQL Transaction") - return nil - } - if ctx.isQuery == false { - err = tx.savepoint() - if err != nil { - sqlLgr.Error().Err(err).Msg("Begin SQL Transaction") - return nil - } - } - cs.tx = tx - return cs.tx.getHandle() -} - -func checkHexString(data string) bool { - if len(data) >= 2 && data[0] == '0' && (data[1] == 'x' || data[1] == 'X') { - return true - } - return false -} - -//export luaCryptoSha256 -func luaCryptoSha256(L *LState, arg unsafe.Pointer, argLen C.int) (*C.char, *C.char) { - data := C.GoBytes(arg, argLen) - if checkHexString(string(data)) { - dataStr := data[2:] - var err error - data, err = hex.Decode(string(dataStr)) - if err != nil { - return nil, C.CString("[Contract.LuaCryptoSha256] hex decoding error: " + err.Error()) - } - } - h := sha256.New() - h.Write(data) - resultHash := h.Sum(nil) - - return C.CString("0x" + hex.Encode(resultHash)), nil -} - -func decodeHex(hexStr string) ([]byte, error) { - if checkHexString(hexStr) { - hexStr = hexStr[2:] - } - return hex.Decode(hexStr) -} - -//export luaECVerify -func luaECVerify(L *LState, service C.int, msg *C.char, sig *C.char, addr *C.char) (C.int, *C.char) { - bMsg, err := decodeHex(C.GoString(msg)) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] invalid message format: " + err.Error()) - } - bSig, err := decodeHex(C.GoString(sig)) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] invalid signature format: " + err.Error()) - } - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.LuaEcVerify]not found contract state") - } - setInstMinusCount(ctx, L, 10000) - - var pubKey *btcec.PublicKey - var verifyResult bool - address := C.GoString(addr) - isAergo := len(address) == types.EncodedAddressLength - - /*Aergo Address*/ - if isAergo { - bAddress, err := types.DecodeAddress(address) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] invalid aergo address: " + err.Error()) - } - pubKey, err = btcec.ParsePubKey(bAddress) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] error parsing pubKey: " + err.Error()) - } - } - - // CompactSign - if len(bSig) == 65 { - // ethereum - if !isAergo { - btcsig := make([]byte, 65) - btcsig[0] = bSig[64] + 27 - copy(btcsig[1:], bSig) - bSig = btcsig - } - pub, _, err := ecdsa.RecoverCompact(bSig, bMsg) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] error recoverCompact: " + err.Error()) - } - if pubKey != nil { - verifyResult = pubKey.IsEqual(pub) - } else { - bAddress, err := decodeHex(address) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] invalid Ethereum address: " + err.Error()) - } - bPub := pub.SerializeUncompressed() - h := sha256.New() - h.Write(bPub[1:]) - signAddress := h.Sum(nil)[12:] - verifyResult = bytes.Equal(bAddress, signAddress) - } - } else { - sign, err := ecdsa.ParseSignature(bSig) - if err != nil { - return -1, C.CString("[Contract.LuaEcVerify] error parsing signature: " + err.Error()) - } - if pubKey == nil { - return -1, C.CString("[Contract.LuaEcVerify] error recovering pubKey") - } - verifyResult = sign.Verify(bMsg, pubKey) - } - if verifyResult { - return C.int(1), nil - } - return C.int(0), nil -} - -func luaCryptoToBytes(data unsafe.Pointer, dataLen C.int) ([]byte, bool) { - var d []byte - b := C.GoBytes(data, dataLen) - isHex := checkHexString(string(b)) - if isHex { - var err error - d, err = hex.Decode(string(b[2:])) - if err != nil { - isHex = false - } - } - if !isHex { - d = b - } - return d, isHex -} - -func luaCryptoRlpToBytes(data unsafe.Pointer) rlpObject { - x := (*C.struct_rlp_obj)(data) - if x.rlp_obj_type == C.RLP_TSTRING { - b, _ := luaCryptoToBytes(x.data, C.int(x.size)) - return rlpString(b) - } - var l rlpList - elems := (*[1 << 30]C.struct_rlp_obj)(unsafe.Pointer(x.data))[:C.int(x.size):C.int(x.size)] - for _, elem := range elems { - b, _ := luaCryptoToBytes(elem.data, C.int(elem.size)) - l = append(l, rlpString(b)) - } - return l -} - -//export luaCryptoVerifyProof -func luaCryptoVerifyProof( - key unsafe.Pointer, keyLen C.int, - value unsafe.Pointer, - hash unsafe.Pointer, hashLen C.int, - proof unsafe.Pointer, nProof C.int, -) C.int { - k, _ := luaCryptoToBytes(key, keyLen) - v := luaCryptoRlpToBytes(value) - h, _ := luaCryptoToBytes(hash, hashLen) - cProof := (*[1 << 30]C.struct_proof)(proof)[:nProof:nProof] - bProof := make([][]byte, int(nProof)) - for i, p := range cProof { - bProof[i], _ = luaCryptoToBytes(p.data, C.int(p.len)) - } - if verifyEthStorageProof(k, v, h, bProof) { - return C.int(1) - } - return C.int(0) -} - -//export luaCryptoKeccak256 -func luaCryptoKeccak256(data unsafe.Pointer, dataLen C.int) (unsafe.Pointer, int) { - d, isHex := luaCryptoToBytes(data, dataLen) - h := keccak256(d) - if isHex { - hexb := []byte("0x" + hex.Encode(h)) - return C.CBytes(hexb), len(hexb) - } else { - return C.CBytes(h), len(h) - } -} - -// transformAmount processes the input string to calculate the total amount, -// taking into account the different units ("aergo", "gaer", "aer") -func transformAmount(amountStr string, forkVersion int32) (*big.Int, error) { - if len(amountStr) == 0 { - return zeroBig, nil - } - - if forkVersion >= 4 { - // Check for amount in decimal format - if strings.Contains(amountStr,".") && strings.HasSuffix(strings.ToLower(amountStr),"aergo") { - // Extract the part before the unit - decimalAmount := amountStr[:len(amountStr)-5] - decimalAmount = strings.TrimRight(decimalAmount, " ") - // Parse the decimal amount - decimalAmount = parseDecimalAmount(decimalAmount, 18) - if decimalAmount == "error" { - return nil, errors.New("converting error for BigNum: " + amountStr) - } - amount, valid := new(big.Int).SetString(decimalAmount, 10) - if !valid { - return nil, errors.New("converting error for BigNum: " + amountStr) - } - return amount, nil - } - } - - totalAmount := new(big.Int) - remainingStr := amountStr - - // Define the units and corresponding multipliers - for _, data := range []struct { - unit string - multiplier *big.Int - }{ - {"aergo", mulAergo}, - {"gaer", mulGaer}, - {"aer", zeroBig}, - } { - idx := strings.Index(strings.ToLower(remainingStr), data.unit) - if idx != -1 { - // Extract the part before the unit - subStr := remainingStr[:idx] - - // Parse and convert the amount - partialAmount, err := parseAndConvert(subStr, data.unit, data.multiplier, amountStr) - if err != nil { - return nil, err - } - - // Add to the total amount - totalAmount.Add(totalAmount, partialAmount) - - // Adjust the remaining string to process - remainingStr = remainingStr[idx+len(data.unit):] - } - } - - // Process the rest of the string, if there is some - if len(remainingStr) > 0 { - partialAmount, err := parseAndConvert(remainingStr, "", zeroBig, amountStr) - if err != nil { - return nil, err - } - - // Add to the total amount - totalAmount.Add(totalAmount, partialAmount) - } - - return totalAmount, nil -} - -// convert decimal amount into big integer string -func parseDecimalAmount(str string, num_decimals int) string { - // Get the integer and decimal parts - idx := strings.Index(str, ".") - if idx == -1 { - return str - } - p1 := str[0:idx] - p2 := str[idx+1:] - - // Check for another decimal point - if strings.Index(p2, ".") != -1 { - return "error" - } - - // Compute the amount of zero digits to add - to_add := num_decimals - len(p2) - if to_add > 0 { - p2 = p2 + strings.Repeat("0", to_add) - } else if to_add < 0 { - // Do not truncate decimal amounts - return "error" - } - - // Join the integer and decimal parts - str = p1 + p2 - - // Remove leading zeros - str = strings.TrimLeft(str, "0") - if str == "" { - str = "0" - } - return str -} - -// parseAndConvert is a helper function to parse the substring as a big integer -// and apply the necessary multiplier based on the unit. -func parseAndConvert(subStr, unit string, mulUnit *big.Int, fullStr string) (*big.Int, error) { - subStr = strings.TrimSpace(subStr) - - // Convert the string to a big integer - amountBig, valid := new(big.Int).SetString(subStr, 10) - if !valid { - // Emits a backwards compatible error message - // the same as: dataType := len(unit) > 0 ? "BigNum" : "Integer" - dataType := map[bool]string{true: "BigNum", false: "Integer"}[len(unit) > 0] - return nil, errors.New("converting error for " + dataType + ": " + strings.TrimSpace(fullStr)) - } - - // Check for negative amounts - if amountBig.Cmp(zeroBig) < 0 { - return nil, errors.New("negative amount not allowed") - } - - // Apply multiplier based on unit - if mulUnit != zeroBig { - amountBig.Mul(amountBig, mulUnit) - } - - return amountBig, nil -} - -//export luaDeployContract -func luaDeployContract( - L *LState, - service C.int, - contract *C.char, - args *C.char, - amount *C.char, -) (C.int, *C.char) { - - argsStr := C.GoString(args) - contractStr := C.GoString(contract) - - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.LuaDeployContract]not found contract state") - } - if ctx.isQuery == true || ctx.nestedView > 0 { - return -1, C.CString("[Contract.LuaDeployContract]send not permitted in query") - } - bs := ctx.bs - - // contract code - var codeABI []byte - var sourceCode []byte - - // check if contract name or address is given - cid, err := getAddressNameResolved(contractStr, bs) - if err == nil { - // check if contract exists - contractState, err := getOnlyContractState(ctx, cid) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]" + err.Error()) - } - // read the contract code - codeABI, err = contractState.GetCode() - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]" + err.Error()) - } else if len(codeABI) == 0 { - return -1, C.CString("[Contract.LuaDeployContract]: not found code") - } - if ctx.blockInfo.ForkVersion >= 4 { - sourceCode = contractState.GetSourceCode() - } - } - - // compile contract code if not found - if len(codeABI) == 0 { - if ctx.blockInfo.ForkVersion >= 2 { - codeABI, err = Compile(contractStr, L) - } else { - codeABI, err = Compile(contractStr, nil) - } - if err != nil { - if C.luaL_hasuncatchablerror(L) != C.int(0) && - C.ERR_BF_TIMEOUT == err.Error() { - return -1, C.CString(C.ERR_BF_TIMEOUT) - } else if err == ErrVmStart { - return -1, C.CString("[Contract.LuaDeployContract] get luaState error") - } - return -1, C.CString("[Contract.LuaDeployContract]compile error:" + err.Error()) - } - if ctx.blockInfo.ForkVersion >= 4 { - sourceCode = []byte(contractStr) - } - } - - err = ctx.addUpdateSize(int64(len(codeABI) + len(sourceCode))) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]:" + err.Error()) - } - - // create account for the contract - prevContractInfo := ctx.curContract - creator := prevContractInfo.callState.accState - newContract, err := state.CreateAccountState(CreateContractID(prevContractInfo.contractId, creator.Nonce()), bs.StateDB) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]:" + err.Error()) - } - contractState, err := statedb.OpenContractState(newContract.ID(), newContract.State(), bs.StateDB) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]:" + err.Error()) - } - - cs := &callState{isCallback: true, isDeploy: true, ctrState: contractState, accState: newContract} - ctx.callState[newContract.AccountID()] = cs - - // read the amount transferred to the contract - amountBig, err := transformAmount(C.GoString(amount), ctx.blockInfo.ForkVersion) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]value not proper format:" + err.Error()) - } - - // read the arguments for the constructor call - var ci types.CallInfo - err = getCallInfo(&ci.Args, []byte(argsStr), newContract.ID()) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract] invalid args:" + err.Error()) - } - - // send the amount to the contract - senderState := prevContractInfo.callState.accState - receiverState := cs.accState - if amountBig.Cmp(zeroBig) > 0 { - if rv := sendBalance(senderState, receiverState, amountBig); rv != nil { - return -1, rv - } - } - - // create a recovery point - seq, err := setRecoveryPoint(newContract.AccountID(), ctx, senderState, cs, amountBig, false, true) - if err != nil { - return -1, C.CString("[System.LuaDeployContract] DB err:" + err.Error()) - } - - // log some info - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[DEPLOY] %s(%s)\n", - types.EncodeAddress(newContract.ID()), newContract.AccountID().String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("deploy snapshot set %d\n", seq)) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("SendBalance : %s\n", amountBig.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", - senderState.Balance().String(), receiverState.Balance().String())) - } - - // set the contract info - ctx.curContract = newContractInfo(cs, prevContractInfo.contractId, newContract.ID(), - receiverState.RP(), amountBig) - defer func() { - ctx.curContract = prevContractInfo - }() - - bytecode := util.LuaCode(codeABI).ByteCode() - - // save the contract code - err = contractState.SetCode(sourceCode, codeABI) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]:" + err.Error()) - } - - // save the contract creator - err = contractState.SetData(dbkey.CreatorMeta(), []byte(types.EncodeAddress(prevContractInfo.contractId))) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract]:" + err.Error()) - } - - // get the remaining gas from the parent LState - ctx.refreshRemainingGas(L) - // create a new executor with the remaining gas on the child LState - ce := newExecutor(bytecode, newContract.ID(), ctx, &ci, amountBig, true, false, contractState) - defer func() { - // close the executor, which will close the child LState - ce.close() - // set the remaining gas on the parent LState - ctx.setRemainingGas(L) - }() - - if ce.err != nil { - return -1, C.CString("[Contract.LuaDeployContract]newExecutor Error :" + ce.err.Error()) - } - - if ctx.blockInfo.ForkVersion < 2 { - // create a sql database for the contract - if db := luaGetDbHandle(ctx.service); db == nil { - return -1, C.CString("[System.LuaDeployContract] DB err: cannot open a database") - } - } - - // increment the nonce of the creator - senderState.SetNonce(senderState.Nonce() + 1) - - addr := C.CString(types.EncodeAddress(newContract.ID())) - ret := C.int(1) - - if ce != nil { - // run the constructor - defer setInstCount(ce.ctx, L, ce.L) - ret += ce.call(minusCallCount(ctx, C.vm_instcount(L), luaCallCountDeduc), L) - - // check if the execution was successful - if ce.err != nil { - // rollback the recovery point - err := clearRecovery(L, ctx, seq, true) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract] recovery error: " + err.Error()) - } - // log some info - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("recovery snapshot: %d\n", seq)) - } - // return the error message - return -1, C.CString("[Contract.LuaDeployContract] call err:" + ce.err.Error()) - } - } - - if seq == 1 { - err := clearRecovery(L, ctx, seq, false) - if err != nil { - return -1, C.CString("[Contract.LuaDeployContract] recovery error: " + err.Error()) - } - } - - return ret, addr -} - -//export isPublic -func isPublic() C.int { - if PubNet { - return C.int(1) - } else { - return C.int(0) - } -} - -//export luaRandomInt -func luaRandomInt(min, max, service C.int) C.int { - ctx := contexts[service] - if ctx.seed == nil { - setRandomSeed(ctx) - } - return C.int(ctx.seed.Intn(int(max+C.int(1)-min)) + int(min)) -} - -//export luaEvent -func luaEvent(L *LState, service C.int, eventName *C.char, args *C.char) *C.char { - ctx := contexts[service] - if ctx.isQuery == true || ctx.nestedView > 0 { - return C.CString("[Contract.Event] event not permitted in query") - } - if ctx.eventCount >= maxEventCnt(ctx) { - return C.CString(fmt.Sprintf("[Contract.Event] exceeded the maximum number of events(%d)", maxEventCnt(ctx))) - } - if len(C.GoString(eventName)) > maxEventNameSize { - return C.CString(fmt.Sprintf("[Contract.Event] exceeded the maximum length of event name(%d)", maxEventNameSize)) - } - if len(C.GoString(args)) > maxEventArgSize { - return C.CString(fmt.Sprintf("[Contract.Event] exceeded the maximum length of event args(%d)", maxEventArgSize)) - } - ctx.events = append( - ctx.events, - &types.Event{ - ContractAddress: ctx.curContract.contractId, - EventIdx: ctx.eventCount, - EventName: C.GoString(eventName), - JsonArgs: C.GoString(args), - }, - ) - ctx.eventCount++ - return nil -} - -//export luaGetEventCount -func luaGetEventCount(L *LState, service C.int) C.int { - eventCount := contexts[service].eventCount - if ctrLgr.IsDebugEnabled() { - ctrLgr.Debug().Int32("eventCount", eventCount).Msg("get event count") - } - return C.int(eventCount) -} - -//export luaDropEvent -func luaDropEvent(L *LState, service C.int, from C.int) { - // Drop all the events after the given index. - ctx := contexts[service] - if ctrLgr.IsDebugEnabled() { - ctrLgr.Debug().Int32("from", int32(from)).Int("len", len(ctx.events)).Msg("drop events") - } - if from >= 0 { - ctx.events = ctx.events[:from] - ctx.eventCount = int32(len(ctx.events)) - } -} - -//export luaToPubkey -func luaToPubkey(L *LState, address *C.char) *C.char { - // check the length of address - if len(C.GoString(address)) != types.EncodedAddressLength { - return C.CString("[Contract.LuaToPubkey] invalid address length") - } - // decode the address in string format to bytes (public key) - pubkey, err := types.DecodeAddress(C.GoString(address)) - if err != nil { - return C.CString("[Contract.LuaToPubkey] invalid address") - } - // return the public key in hex format - return C.CString("0x" + hex.Encode(pubkey)) -} - -//export luaToAddress -func luaToAddress(L *LState, pubkey *C.char) *C.char { - // decode the pubkey in hex format to bytes - pubkeyBytes, err := decodeHex(C.GoString(pubkey)) - if err != nil { - return C.CString("[Contract.LuaToAddress] invalid public key") - } - // check the length of pubkey - if len(pubkeyBytes) != types.AddressLength { - return C.CString("[Contract.LuaToAddress] invalid public key length") - // or convert the pubkey to compact format - SerializeCompressed() - } - // encode the pubkey in bytes to an address in string format - address := types.EncodeAddress(pubkeyBytes) - // return the address - return C.CString(address) -} - -//export luaIsContract -func luaIsContract(L *LState, service C.int, contractId *C.char) (C.int, *C.char) { - - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.LuaIsContract] contract state not found") - } - - cid, err := getAddressNameResolved(C.GoString(contractId), ctx.bs) - if err != nil { - return -1, C.CString("[Contract.LuaIsContract] invalid contractId: " + err.Error()) - } - - cs, err := getCallState(ctx, cid) - if err != nil { - return -1, C.CString("[Contract.LuaIsContract] getAccount error: " + err.Error()) - } - - return C.int(len(cs.accState.CodeHash())), nil -} - -//export luaNameResolve -func luaNameResolve(L *LState, service C.int, name_or_address *C.char) *C.char { - ctx := contexts[service] - if ctx == nil { - return C.CString("[Contract.LuaNameResolve] contract state not found") - } - var addr []byte - var err error - account := C.GoString(name_or_address) - if len(account) == types.EncodedAddressLength { - // also checks if valid address - addr, err = types.DecodeAddress(account) - } else { - addr, err = name.Resolve(ctx.bs, []byte(account), false) - } - if err != nil { - return C.CString("[Contract.LuaNameResolve] " + err.Error()) - } - return C.CString(types.EncodeAddress(addr)) -} - -//export luaGovernance -func luaGovernance(L *LState, service C.int, gType C.char, arg *C.char) *C.char { - - ctx := contexts[service] - if ctx == nil { - return C.CString("[Contract.LuaGovernance] contract state not found") - } - - if ctx.isQuery == true || ctx.nestedView > 0 { - return C.CString("[Contract.LuaGovernance] governance not permitted in query") - } - - var amountBig *big.Int - var payload []byte - - switch gType { - case 'S', 'U': - var err error - amountBig, err = transformAmount(C.GoString(arg), ctx.blockInfo.ForkVersion) - if err != nil { - return C.CString("[Contract.LuaGovernance] invalid amount: " + err.Error()) - } - if gType == 'S' { - payload = []byte(fmt.Sprintf(`{"Name":"%s"}`, types.Opstake.Cmd())) - } else { - payload = []byte(fmt.Sprintf(`{"Name":"%s"}`, types.Opunstake.Cmd())) - } - case 'V': - amountBig = zeroBig - payload = []byte(fmt.Sprintf(`{"Name":"%s","Args":%s}`, types.OpvoteBP.Cmd(), C.GoString(arg))) - case 'D': - amountBig = zeroBig - payload = []byte(fmt.Sprintf(`{"Name":"%s","Args":%s}`, types.OpvoteDAO.Cmd(), C.GoString(arg))) - } - - cid := []byte(types.AergoSystem) - aid := types.ToAccountID(cid) - scsState, err := getContractState(ctx, cid) - if err != nil { - return C.CString("[Contract.LuaGovernance] getAccount error: " + err.Error()) - } - - curContract := ctx.curContract - - senderState := curContract.callState.accState - receiverState := scsState.accState - - txBody := types.TxBody{ - Amount: amountBig.Bytes(), - Payload: payload, - } - if ctx.blockInfo.ForkVersion >= 2 { - txBody.Account = curContract.contractId - } - - err = types.ValidateSystemTx(&txBody) - if err != nil { - return C.CString("[Contract.LuaGovernance] error: " + err.Error()) - } - - seq, err := setRecoveryPoint(aid, ctx, senderState, scsState, zeroBig, false, false) - if err != nil { - return C.CString("[Contract.LuaGovernance] database error: " + err.Error()) - } - - events, err := system.ExecuteSystemTx(scsState.ctrState, &txBody, senderState, receiverState, ctx.blockInfo) - if err != nil { - rErr := clearRecovery(L, ctx, seq, true) - if rErr != nil { - return C.CString("[Contract.LuaGovernance] recovery error: " + rErr.Error()) - } - return C.CString("[Contract.LuaGovernance] error: " + err.Error()) - } - - if seq == 1 { - err := clearRecovery(L, ctx, seq, false) - if err != nil { - return C.CString("[Contract.LuaGovernance] recovery error: " + err.Error()) - } - } - - ctx.eventCount += int32(len(events)) - ctx.events = append(ctx.events, events...) - - if ctx.lastRecoveryEntry != nil { - if gType == 'S' { - seq, _ = setRecoveryPoint(aid, ctx, senderState, scsState, amountBig, true, false) - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[GOVERNANCE]aid(%s)\n", aid.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("staking : %s\n", amountBig.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", - senderState.Balance().String(), receiverState.Balance().String())) - } - } else if gType == 'U' { - seq, _ = setRecoveryPoint(aid, ctx, receiverState, ctx.curContract.callState, amountBig, true, false) - if ctx.traceFile != nil { - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("[GOVERNANCE]aid(%s)\n", aid.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("snapshot set %d\n", seq)) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("unstaking : %s\n", amountBig.String())) - _, _ = ctx.traceFile.WriteString(fmt.Sprintf("After sender: %s receiver: %s\n", - senderState.Balance().String(), receiverState.Balance().String())) - } - } - } - - return nil -} - -//export luaViewStart -func luaViewStart(service C.int) { - ctx := contexts[service] - ctx.nestedView++ -} - -//export luaViewEnd -func luaViewEnd(service C.int) { - ctx := contexts[service] - ctx.nestedView-- -} - -//export luaCheckView -func luaCheckView(service C.int) C.int { - ctx := contexts[service] - return C.int(ctx.nestedView) -} - -// luaCheckTimeout checks whether the block creation timeout occurred. -// -//export luaCheckTimeout -func luaCheckTimeout(service C.int) C.int { - - if service < BlockFactory { - // Originally, MaxVmService was used instead of maxContext. service - // value can be 2 and decremented by MaxVmService(=2) during VM loading. - // That means the value of service becomes zero after the latter - // adjustment. - // - // This make the VM check block timeout in a unwanted situation. If that - // happens during the chain service is connecting block, the block chain - // becomes out of sync. - service = service + C.int(maxContext) - } - - if service != BlockFactory { - return 0 - } - - ctx := contexts[service] - select { - case <-ctx.execCtx.Done(): - return 1 - default: - return 0 - } -} - -//export luaIsFeeDelegation -func luaIsFeeDelegation(L *LState, service C.int) (C.int, *C.char) { - ctx := contexts[service] - if ctx == nil { - return -1, C.CString("[Contract.LuaIsContract] contract state not found") - } - if ctx.isFeeDelegation { - return 1, nil - } - return 0, nil -} - -//export LuaGetDbHandleSnap -func LuaGetDbHandleSnap(service C.int, snap *C.char) *C.char { - - stateSet := contexts[service] - curContract := stateSet.curContract - callState := curContract.callState - - if stateSet.isQuery != true { - return C.CString("[Contract.LuaSetDbSnap] not permitted in transaction") - } - - if callState.tx != nil { - return C.CString("[Contract.LuaSetDbSnap] transaction already started") - } - - rp, err := strconv.ParseUint(C.GoString(snap), 10, 64) - if err != nil { - return C.CString("[Contract.LuaSetDbSnap] snapshot is not valid" + C.GoString(snap)) - } - - aid := types.ToAccountID(curContract.contractId) - tx, err := beginReadOnly(aid.String(), rp) - if err != nil { - return C.CString("Error Begin SQL Transaction") - } - - callState.tx = tx - return nil -} - -//export LuaGetDbSnapshot -func LuaGetDbSnapshot(service C.int) *C.char { - stateSet := contexts[service] - curContract := stateSet.curContract - - return C.CString(strconv.FormatUint(curContract.rp, 10)) -} - -//export luaGetStaking -func luaGetStaking(service C.int, addr *C.char) (*C.char, C.lua_Integer, *C.char) { - - var ( - ctx *vmContext - scs, namescs *statedb.ContractState - err error - staking *types.Staking - ) - - ctx = contexts[service] - scs, err = statedb.GetSystemAccountState(ctx.bs.StateDB) - if err != nil { - return nil, 0, C.CString(err.Error()) - } - - namescs, err = statedb.GetNameAccountState(ctx.bs.StateDB) - if err != nil { - return nil, 0, C.CString(err.Error()) - } - - staking, err = system.GetStaking(scs, name.GetAddress(namescs, types.ToAddress(C.GoString(addr)))) - if err != nil { - return nil, 0, C.CString(err.Error()) - } - - return C.CString(staking.GetAmountBigInt().String()), C.lua_Integer(staking.When), nil -} - -func sendBalance(sender *state.AccountState, receiver *state.AccountState, amount *big.Int) *C.char { - if err := state.SendBalance(sender, receiver, amount); err != nil { - return C.CString("[Contract.sendBalance] insufficient balance: " + - sender.Balance().String() + " : " + amount.String()) - } - return nil -} diff --git a/contract/vm_direct/vm_direct.go b/contract/vm_direct/vm_direct.go index 98cadc329..052927ecf 100644 --- a/contract/vm_direct/vm_direct.go +++ b/contract/vm_direct/vm_direct.go @@ -32,10 +32,6 @@ const ( ChainTypeUnitTest ) -const ( - lStateMaxSize = 10 * 7 -) - var ( logger *log.Logger ) @@ -134,7 +130,7 @@ func LoadDummyChainEx(chainType ChainType) (*DummyChain, error) { contract.LoadTestDatabase(dataPath) contract.SetStateSQLMaxDBSize(1024) - contract.StartLStateFactory(lStateMaxSize, config.GetDefaultNumLStateClosers(), 1) + contract.StartVMPool(contract.MaxPossibleCallDepth()) contract.InitContext(3) // To pass the governance tests. diff --git a/contract/vm_dummy/test_files/feature_pcall_rollback_4a.lua b/contract/vm_dummy/test_files/feature_pcall_rollback_4a.lua index d8765e13a..a7867967a 100644 --- a/contract/vm_dummy/test_files/feature_pcall_rollback_4a.lua +++ b/contract/vm_dummy/test_files/feature_pcall_rollback_4a.lua @@ -4,19 +4,21 @@ state.var { values = state.map() } -function constructor(resolver_address, contract_name) +function constructor(resolver_address, contract_name, use_db) -- initialize state variables resolver:set(resolver_address) name:set(contract_name) -- initialize db - db.exec("create table config (value integer primary key) without rowid") - db.exec("insert into config values (0)") - db.exec([[create table products ( - id integer primary key, - name text not null, - price real) - ]]) - db.exec("insert into products (name,price) values ('first', 1234.56)") + if use_db then + db.exec("create table config (value integer primary key) without rowid") + db.exec("insert into config values (0)") + db.exec([[create table products ( + id integer primary key, + name text not null, + price real) + ]]) + db.exec("insert into products (name,price) values ('first', 1234.56)") + end end function resolve(name) diff --git a/contract/vm_dummy/test_files/feature_pcall_rollback_4b.lua b/contract/vm_dummy/test_files/feature_pcall_rollback_4b.lua index bc1c40929..2a199410b 100644 --- a/contract/vm_dummy/test_files/feature_pcall_rollback_4b.lua +++ b/contract/vm_dummy/test_files/feature_pcall_rollback_4b.lua @@ -4,19 +4,21 @@ state.var { values = state.map() } -function constructor(resolver_address, contract_name) +function constructor(resolver_address, contract_name, use_db) -- initialize state variables resolver:set(resolver_address) name:set(contract_name) -- initialize db - db.exec("create table config (value integer primary key) without rowid") - db.exec("insert into config values (0)") - db.exec([[create table products ( - id integer primary key, - name text not null, - price real) - ]]) - db.exec("insert into products (name,price) values ('first', 1234.56)") + if use_db then + db.exec("create table config (value integer primary key) without rowid") + db.exec("insert into config values (0)") + db.exec([[create table products ( + id integer primary key, + name text not null, + price real) + ]]) + db.exec("insert into products (name,price) values ('first', 1234.56)") + end end function resolve(name) diff --git a/contract/vm_dummy/test_files/feature_pcall_rollback_4c.lua b/contract/vm_dummy/test_files/feature_pcall_rollback_4c.lua index 4e0ed37f2..cd06aee92 100644 --- a/contract/vm_dummy/test_files/feature_pcall_rollback_4c.lua +++ b/contract/vm_dummy/test_files/feature_pcall_rollback_4c.lua @@ -4,19 +4,21 @@ state.var { values = state.map() } -function constructor(resolver_address, contract_name) +function constructor(resolver_address, contract_name, use_db) -- initialize state variables resolver:set(resolver_address) name:set(contract_name) -- initialize db - db.exec("create table config (value integer primary key) without rowid") - db.exec("insert into config values (0)") - db.exec([[create table products ( - id integer primary key, - name text not null, - price real) - ]]) - db.exec("insert into products (name,price) values ('first', 1234.56)") + if use_db then + db.exec("create table config (value integer primary key) without rowid") + db.exec("insert into config values (0)") + db.exec([[create table products ( + id integer primary key, + name text not null, + price real) + ]]) + db.exec("insert into products (name,price) values ('first', 1234.56)") + end end function resolve(name) diff --git a/contract/vm_dummy/test_files/infiniteloop.lua b/contract/vm_dummy/test_files/infiniteloop.lua index 8739e16d7..61bc937cc 100644 --- a/contract/vm_dummy/test_files/infiniteloop.lua +++ b/contract/vm_dummy/test_files/infiniteloop.lua @@ -1,4 +1,4 @@ -function infiniteLoop() +function infinite_loop() local t = 0 while true do t = t + 1 @@ -6,16 +6,25 @@ function infiniteLoop() return t end -function infiniteCall() - infiniteCall() +function infinite_call() + infinite_call() end -function catch() - return pcall(infiniteLoop) +function catch_loop() + return pcall(infinite_loop) end -function contract_catch() - return contract.pcall(infiniteLoop) +function catch_call() + return pcall(infinite_call) end -abi.register(infiniteLoop, infiniteCall, catch, contract_catch) +function contract_catch_loop() + return contract.pcall(infinite_loop) +end + +function contract_catch_call() + return contract.pcall(infinite_call) +end + +abi.register(infinite_loop, catch_loop, contract_catch_loop) +abi.register(infinite_call, catch_call, contract_catch_call) diff --git a/contract/vm_dummy/test_files/type_maxstring.lua b/contract/vm_dummy/test_files/type_maxstring.lua index cf7d63c48..391326a92 100644 --- a/contract/vm_dummy/test_files/type_maxstring.lua +++ b/contract/vm_dummy/test_files/type_maxstring.lua @@ -1,17 +1,71 @@ -function oom() +function oom_string() local s = "hello" - while 1 do s = s .. s end end -function p() - pcall(oom) +function oom_table1() + local t = {} + for i = 1, 10000000 do + t[i] = i + end +end + +function oom_table2() + local t = {} + for i = 1, 10000000 do + local str = "key" .. tostring(i) + t[str] = str + end +end + +function oom_global() + local s1 = "hello" + local s2 = "hello" + local s3 = "hello" + local s4 = "hello" + while 1 do + s1 = s1 .. s1 + s2 = s2 .. s2 + s3 = s3 .. s3 + s4 = s4 .. s4 + end +end + +function pcall_string() + pcall(oom_string) +end + +function pcall_table1() + pcall(oom_table1) +end + +function pcall_table2() + pcall(oom_table2) +end + +function pcall_global() + pcall(oom_global) +end + +function contract_pcall_string() + contract.pcall(oom_string) +end + +function contract_pcall_table1() + contract.pcall(oom_table1) +end + +function contract_pcall_table2() + contract.pcall(oom_table2) end -function cp() - contract.pcall(oom) +function contract_pcall_global() + contract.pcall(oom_global) end -abi.register(oom, p, cp) +abi.register(oom_string, pcall_string, contract_pcall_string) +abi.register(oom_table1, pcall_table1, contract_pcall_table1) +abi.register(oom_table2, pcall_table2, contract_pcall_table2) +abi.register(oom_global, pcall_global, contract_pcall_global) diff --git a/contract/vm_dummy/vm_dummy.go b/contract/vm_dummy/vm_dummy.go index b867941ad..ec17528d4 100644 --- a/contract/vm_dummy/vm_dummy.go +++ b/contract/vm_dummy/vm_dummy.go @@ -16,7 +16,6 @@ import ( "github.com/aergoio/aergo-lib/db" "github.com/aergoio/aergo-lib/log" "github.com/aergoio/aergo/v2/cmd/aergoluac/util" - "github.com/aergoio/aergo/v2/config" "github.com/aergoio/aergo/v2/contract" "github.com/aergoio/aergo/v2/contract/system" "github.com/aergoio/aergo/v2/fee" @@ -30,11 +29,10 @@ import ( var ( logger *log.Logger + prevPubNet bool ) const ( - lStateMaxSize = 10 * 7 - dummyBlockIntervalSec = 1 dummyBlockExecTimeMs = (dummyBlockIntervalSec * 1000) >> 2 ) @@ -45,6 +43,7 @@ func init() { type DummyChain struct { HardforkVersion int32 + PubNet bool sdb *state.ChainStateDB bestBlock *types.Block cBlock *types.Block @@ -55,7 +54,6 @@ type DummyChain struct { testReceiptDB db.DB tmpDir string timeout int - clearLState func() gasPrice *big.Int timestamp int64 } @@ -75,32 +73,58 @@ func SetTimeout(timeout int) DummyChainOptions { } } +// used on brick + func SetPubNet() DummyChainOptions { return func(dc *DummyChain) { + contract.PubNet = true + dc.PubNet = true + } +} - // public and private chains have different features. - // private chains have the db module and public ones don't. - // this is why we need to flush all Lua states and recreate - // them when moving to and from public chain. +func SetPrivNet() DummyChainOptions { + return func(dc *DummyChain) { + contract.PubNet = false + dc.PubNet = false + } +} - contract.PubNet = true - fee.DisableZeroFee() - contract.FlushLStates() +// used on tests - dc.clearLState = func() { - contract.PubNet = false - fee.EnableZeroFee() - contract.FlushLStates() - } +func RunOnPubNet() DummyChainOptions { + return func(dc *DummyChain) { + dc.PubNet = true + } +} + +func RunOnPrivNet() DummyChainOptions { + return func(dc *DummyChain) { + dc.PubNet = false + } +} + +func RunOnAllNets() DummyChainOptions { + return func(dc *DummyChain) { + dc.PubNet = contract.PubNet } } func LoadDummyChain(opts ...DummyChainOptions) (*DummyChain, error) { + + // skip test if pubnet is different + bc := &DummyChain{} + for _, opt := range opts { + opt(bc) + } + if bc.PubNet != contract.PubNet { + return nil, nil + } + dataPath, err := os.MkdirTemp("", "data") if err != nil { return nil, err } - bc := &DummyChain{ + bc = &DummyChain{ sdb: state.NewChainStateDB(), tmpDir: dataPath, gasPrice: types.NewAmount(1, types.Aer), @@ -129,11 +153,8 @@ func LoadDummyChain(opts ...DummyChainOptions) (*DummyChain, error) { bc.testReceiptDB = db.NewDB(db.MemoryImpl, path.Join(dataPath, "receiptDB")) contract.LoadTestDatabase(dataPath) // sql database contract.SetStateSQLMaxDBSize(1024) - contract.StartLStateFactory(lStateMaxSize, config.GetDefaultNumLStateClosers(), 1) contract.InitContext(3) - bc.HardforkVersion = 2 - // To pass the governance tests. types.InitGovernance("dpos", true) @@ -141,19 +162,40 @@ func LoadDummyChain(opts ...DummyChainOptions) (*DummyChain, error) { scs, err := statedb.GetSystemAccountState(bc.sdb.GetStateDB()) system.InitSystemParams(scs, 3) - fee.EnableZeroFee() + // set default values + bc.HardforkVersion = 2 + bc.PubNet = false + // process options for _, opt := range opts { opt(bc) } + + if !contract.VmPoolStarted { + contract.CurrentForkVersion = bc.HardforkVersion + contract.StartVMPool(contract.MaxPossibleCallDepth()) + } else if (contract.CurrentForkVersion != bc.HardforkVersion) || (contract.PubNet != prevPubNet) { + // public and private chains have different features. + // private chains have the db module and public ones don't. + // this is why we need to flush all Lua VM instances and + // recreate them when moving to and from public chain. + contract.CurrentForkVersion = bc.HardforkVersion + contract.FlushVmInstances() + } + + prevPubNet = contract.PubNet + + if contract.PubNet { + fee.DisableZeroFee() + } else { + fee.EnableZeroFee() + } + return bc, nil } func (bc *DummyChain) Release() { bc.testReceiptDB.Close() - if bc.clearLState != nil { - bc.clearLState() - } _ = os.RemoveAll(bc.tmpDir) } @@ -593,7 +635,7 @@ func (l *luaTxDeploy) run(execCtx context.Context, bs *state.BlockState, bc *Dum // compile the plain code to bytecode payload := util.LuaCodePayload(l._payload) code := string(payload.Code()) - byteCode, err := contract.Compile(code, nil) + byteCode, err := contract.Compile(code, false) if err != nil { return err } diff --git a/contract/vm_dummy/vm_dummy_pub_test.go b/contract/vm_dummy/vm_dummy_pub_test.go index 69416ef87..da8985eed 100644 --- a/contract/vm_dummy/vm_dummy_pub_test.go +++ b/contract/vm_dummy/vm_dummy_pub_test.go @@ -28,9 +28,16 @@ func TestContractSendF(t *testing.T) { code := readLuaCode(t, "contract_sendf_1.lua") code2 := readLuaCode(t, "contract_sendf_2.lua") - for version := int32(3); version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version), SetPubNet()) + // skip if current version is less than 3 + if currentVersion < 3 { + t.Skipf("%s: skip version less than 3", t.Name()) + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -45,7 +52,7 @@ func TestContractSendF(t *testing.T) { require.NoErrorf(t, err, "failed to connect new block") r := bc.GetReceipt(tx.Hash()) - expectedGas := map[int32]int64{3: 105087, 4: 105087}[version] + expectedGas := map[int32]int64{3: 105087, 4: 105087}[currentVersion] assert.Equalf(t, expectedGas, int64(r.GetGasUsed()), "gas used not equal") state, err := bc.GetAccountState("test2") @@ -56,12 +63,12 @@ func TestContractSendF(t *testing.T) { require.NoErrorf(t, err, "failed to connect new block") r = bc.GetReceipt(tx.Hash()) - expectedGas = map[int32]int64{3: 105179, 4: 105755}[version] + expectedGas = map[int32]int64{3: 105179, 4: 105755}[currentVersion] assert.Equalf(t, expectedGas, int64(r.GetGasUsed()), "gas used not equal") state, err = bc.GetAccountState("test2") assert.Equalf(t, int64(6), state.GetBalanceBigInt().Int64(), "balance state not equal") - } + } func TestGasPerFunction(t *testing.T) { @@ -70,8 +77,11 @@ func TestGasPerFunction(t *testing.T) { var err error code := readLuaCode(t, "gas_per_function.lua") - bc, err := LoadDummyChain(SetPubNet()) - assert.NoError(t, err) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet()) + require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -82,7 +92,7 @@ func TestGasPerFunction(t *testing.T) { NewLuaTxDeploy("user", "contract_v3", 0, code), NewLuaTxDeploy("user", "contract_v4", 0, code), ) - assert.NoError(t, err) + require.NoError(t, err) // transfer funds to the contracts err = bc.ConnectBlock( @@ -90,7 +100,7 @@ func TestGasPerFunction(t *testing.T) { NewLuaTxCall("user", "contract_v3", uint64(10e18), `{"Name":"default"}`), NewLuaTxCall("user", "contract_v4", uint64(10e18), `{"Name":"default"}`), ) - assert.NoError(t, err, "sending funds to contracts") + require.NoError(t, err, "sending funds to contracts") tests_v2 := []struct { funcName string @@ -463,8 +473,7 @@ func TestGasPerFunction(t *testing.T) { {"contract.event", "", 0, 163452}, } - // set the hard fork version - bc.HardforkVersion = 2 + if currentVersion == 2 { // iterate over the tests for _, test := range tests_v2 { @@ -481,7 +490,7 @@ func TestGasPerFunction(t *testing.T) { } tx := NewLuaTxCall("user", "contract_v2", uint64(amount), payload) err = bc.ConnectBlock(tx) - assert.NoError(t, err, "while executing %s", funcName) + require.NoError(t, err, "while executing %s", funcName) usedGas := bc.GetReceipt(tx.Hash()).GetGasUsed() assert.Equal(t, expectedGas, int64(usedGas), "wrong used gas for %s", funcName) @@ -493,8 +502,7 @@ func TestGasPerFunction(t *testing.T) { //fmt.Printf("add_test \"%s\" %d\n", funcName, usedGas) } - // set the hard fork version - bc.HardforkVersion = 3 + } else if currentVersion == 3 { // iterate over the tests for _, test := range tests_v3 { @@ -511,7 +519,7 @@ func TestGasPerFunction(t *testing.T) { } tx := NewLuaTxCall("user", "contract_v3", uint64(amount), payload) err = bc.ConnectBlock(tx) - assert.NoError(t, err, "while executing %s", funcName) + require.NoError(t, err, "while executing %s", funcName) usedGas := bc.GetReceipt(tx.Hash()).GetGasUsed() assert.Equal(t, expectedGas, int64(usedGas), "wrong used gas for %s", funcName) @@ -523,8 +531,7 @@ func TestGasPerFunction(t *testing.T) { //fmt.Printf("add_test \"%s\" %d\n", funcName, usedGas) } - // set the hard fork version - bc.HardforkVersion = 4 + } else if currentVersion == 4 { // iterate over the tests for _, test := range tests_v4 { @@ -541,7 +548,7 @@ func TestGasPerFunction(t *testing.T) { } tx := NewLuaTxCall("user", "contract_v4", uint64(amount), payload) err = bc.ConnectBlock(tx) - assert.NoError(t, err, "while executing %s", funcName) + require.NoError(t, err, "while executing %s", funcName) usedGas := bc.GetReceipt(tx.Hash()).GetGasUsed() assert.Equal(t, expectedGas, int64(usedGas), "wrong used gas for %s", funcName) @@ -553,6 +560,8 @@ func TestGasPerFunction(t *testing.T) { //fmt.Printf("add_test \"%s\" %d\n", funcName, usedGas) } + } + } func TestGasHello(t *testing.T) { @@ -561,22 +570,30 @@ func TestGasHello(t *testing.T) { var err error code := readLuaCode(t, "contract_hello.lua") - err = expectGas(code, 0, `"hello"`, `"world"`, 100000, SetHardForkVersion(1)) - assert.NoError(t, err) - - err = expectGas(code, 0, `"hello"`, `"w"`, 101203+3*1, SetHardForkVersion(2)) - assert.NoError(t, err) - err = expectGas(code, 0, `"hello"`, `"wo"`, 101203+3*2, SetHardForkVersion(2)) - assert.NoError(t, err) - err = expectGas(code, 0, `"hello"`, `"wor"`, 101203+3*3, SetHardForkVersion(2)) - assert.NoError(t, err) - err = expectGas(code, 0, `"hello"`, `"worl"`, 101203+3*4, SetHardForkVersion(2)) - assert.NoError(t, err) - err = expectGas(code, 0, `"hello"`, `"world"`, 101203+3*5, SetHardForkVersion(2)) - assert.NoError(t, err) - - err = expectGas(code, 0, `"hello"`, `"world"`, 101203+3*5, SetHardForkVersion(3)) - assert.NoError(t, err) + if currentVersion == 1 { + + err = expectGas(code, 0, `"hello"`, `"world"`, 100000) + assert.NoError(t, err) + + } else if currentVersion == 2 { + + err = expectGas(code, 0, `"hello"`, `"w"`, 101203+3*1) + assert.NoError(t, err) + err = expectGas(code, 0, `"hello"`, `"wo"`, 101203+3*2) + assert.NoError(t, err) + err = expectGas(code, 0, `"hello"`, `"wor"`, 101203+3*3) + assert.NoError(t, err) + err = expectGas(code, 0, `"hello"`, `"worl"`, 101203+3*4) + assert.NoError(t, err) + err = expectGas(code, 0, `"hello"`, `"world"`, 101203+3*5) + assert.NoError(t, err) + + } else if currentVersion >= 3 { + + err = expectGas(code, 0, `"hello"`, `"world"`, 101203+3*5) + assert.NoError(t, err) + + } } func TestGasDeploy(t *testing.T) { @@ -585,17 +602,28 @@ func TestGasDeploy(t *testing.T) { var err error code := readLuaCode(t, "gas_deploy.lua") - // err = expectGas(code, 0, `"testPcall"`, ``, 0, SetHardForkVersion(0)) - // assert.NoError(t, err) + if currentVersion <= 1 { + + err = expectGas(code, 0, `"testPcall"`, ``, 0) + assert.NoError(t, err) + + } else if currentVersion == 2 { + + err = expectGas(code, 0, `"testPcall"`, ``, 117861) + assert.NoError(t, err) + + } else if currentVersion == 3 { + + err = expectGas(code, 0, `"testPcall"`, ``, 117861) + assert.NoError(t, err) - err = expectGas(code, 0, `"testPcall"`, ``, 117861, SetHardForkVersion(2)) - assert.NoError(t, err) + } else if currentVersion == 4 { - err = expectGas(code, 0, `"testPcall"`, ``, 117861, SetHardForkVersion(3)) - assert.NoError(t, err) + err = expectGas(code, 0, `"testPcall"`, ``, 118350) + assert.NoError(t, err) + + } - err = expectGas(code, 0, `"testPcall"`, ``, 118350, SetHardForkVersion(4)) - assert.NoError(t, err) } func TestGasOp(t *testing.T) { @@ -604,17 +632,27 @@ func TestGasOp(t *testing.T) { var err error code := readLuaCode(t, "gas_op.lua") - err = expectGas(string(code), 0, `"main"`, ``, 100000, SetHardForkVersion(0)) - assert.NoError(t, err) + if currentVersion <= 1 { + + err = expectGas(string(code), 0, `"main"`, ``, 100000) + assert.NoError(t, err) + + } else if currentVersion == 2 { + + err = expectGas(string(code), 0, `"main"`, ``, 117610) + assert.NoError(t, err) + + } else if currentVersion == 3 { - err = expectGas(string(code), 0, `"main"`, ``, 117610, SetHardForkVersion(2)) - assert.NoError(t, err) + err = expectGas(string(code), 0, `"main"`, ``, 117610) + assert.NoError(t, err) - err = expectGas(string(code), 0, `"main"`, ``, 117610, SetHardForkVersion(3)) - assert.NoError(t, err) + } else if currentVersion == 4 { - err = expectGas(string(code), 0, `"main"`, ``, 120832, SetHardForkVersion(4)) - assert.NoError(t, err) + err = expectGas(string(code), 0, `"main"`, ``, 120832) + assert.NoError(t, err) + + } } func TestGasBF(t *testing.T) { @@ -624,17 +662,27 @@ func TestGasBF(t *testing.T) { code2 := readLuaCode(t, "gas_bf_v2.lua") code4 := readLuaCode(t, "gas_bf_v4.lua") - // err = expectGas(t, string(code), 0, `"main"`, ``, 100000, SetHardForkVersion(1), SetTimeout(500)) - // assert.NoError(t, err) + if currentVersion <= 1 { + + // err = expectGas(t, string(code), 0, `"main"`, ``, 100000, SetTimeout(500)) + // assert.NoError(t, err) + + } else if currentVersion == 2 { + + err = expectGas(string(code2), 0, `"main"`, ``, 47456244, SetTimeout(500)) + assert.NoError(t, err) + + } else if currentVersion == 3 { + + err = expectGas(string(code2), 0, `"main"`, ``, 47456046, SetTimeout(500)) + assert.NoError(t, err) - err = expectGas(string(code2), 0, `"main"`, ``, 47456244, SetHardForkVersion(2), SetTimeout(500)) - assert.NoError(t, err) + } else if currentVersion == 4 { - err = expectGas(string(code2), 0, `"main"`, ``, 47456046, SetHardForkVersion(3), SetTimeout(500)) - assert.NoError(t, err) + err = expectGas(string(code4), 0, `"main"`, ``, 47342481, SetTimeout(500)) + assert.NoError(t, err) - err = expectGas(string(code4), 0, `"main"`, ``, 47342481, SetHardForkVersion(4), SetTimeout(500)) - assert.NoError(t, err) + } } func TestGasLuaCryptoVerifyProof(t *testing.T) { @@ -642,37 +690,48 @@ func TestGasLuaCryptoVerifyProof(t *testing.T) { code := readLuaCode(t, "feature_crypto_verify_proof.lua") - // v2 raw - err := expectGas(string(code), 0, `"verifyProofRaw"`, ``, 154137, SetHardForkVersion(2)) - assert.NoError(t, err) + if currentVersion == 2 { + + // v2 raw + err := expectGas(string(code), 0, `"verifyProofRaw"`, ``, 154137) + assert.NoError(t, err) + + // v2 hex + err = expectGas(string(code), 0, `"verifyProofHex"`, ``, 108404) + assert.NoError(t, err) - // v2 hex - err = expectGas(string(code), 0, `"verifyProofHex"`, ``, 108404, SetHardForkVersion(2)) - assert.NoError(t, err) + } else if currentVersion == 3 { - // v3 raw - err = expectGas(string(code), 0, `"verifyProofRaw"`, ``, 154137, SetHardForkVersion(3)) - assert.NoError(t, err) + // v3 raw + err := expectGas(string(code), 0, `"verifyProofRaw"`, ``, 154137) + assert.NoError(t, err) - // v3 hex - err = expectGas(string(code), 0, `"verifyProofHex"`, ``, 108404, SetHardForkVersion(3)) - assert.NoError(t, err) + // v3 hex + err = expectGas(string(code), 0, `"verifyProofHex"`, ``, 108404) + assert.NoError(t, err) - // v4 raw - err = expectGas(string(code), 0, `"verifyProofRaw"`, ``, 160281, SetHardForkVersion(4)) - assert.NoError(t, err) + } else if currentVersion == 4 { - // v4 hex - err = expectGas(string(code), 0, `"verifyProofHex"`, ``, 108404, SetHardForkVersion(4)) - assert.NoError(t, err) + // v4 raw + err := expectGas(string(code), 0, `"verifyProofRaw"`, ``, 160281) + assert.NoError(t, err) + + // v4 hex + err = expectGas(string(code), 0, `"verifyProofHex"`, ``, 108404) + assert.NoError(t, err) + + } } func expectGas(contractCode string, amount int64, funcName, funcArgs string, expectGas int64, opt ...DummyChainOptions) error { // append set pubnet - bc, err := LoadDummyChain(append(opt, SetPubNet())...) + bc, err := LoadDummyChain(append(opt, RunOnPubNet(), SetHardForkVersion(currentVersion))...) if err != nil { return err } + if bc == nil { + return nil + } defer bc.Release() if err = bc.ConnectBlock( @@ -724,9 +783,15 @@ func TestTypeInvalidKey(t *testing.T) { code := readLuaCode(t, "type_invalidkey.lua") - for version := int32(3); version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + if currentVersion < 3 { + t.Skip() + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "invalidkey", 0, code)) @@ -752,7 +817,7 @@ func TestTypeInvalidKey(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "invalidkey", 0, `{"Name":"key_nil"}`).Fail("invalid key type: 'nil', state.map: 'h'")) require.NoErrorf(t, err, "failed to call tx") - } + } func TestTypeBigTable(t *testing.T) { @@ -764,9 +829,15 @@ func TestTypeBigTable(t *testing.T) { code := readLuaCode(t, "type_bigtable_1.lua") code2 := readLuaCode(t, "type_bigtable_2.lua") - for version := int32(3); version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + if currentVersion < 3 { + t.Skip() + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "big", 0, code)) @@ -786,5 +857,5 @@ func TestTypeBigTable(t *testing.T) { } err = bc.ConnectBlock(NewLuaTxCall("user1", "big20", 0, `{"Name": "inserts"}`).Fail("database or disk is full")) require.NoErrorf(t, err, "failed to call tx") - } + } diff --git a/contract/vm_dummy/vm_dummy_test.go b/contract/vm_dummy/vm_dummy_test.go index 5cd69ecca..9422cbe9b 100644 --- a/contract/vm_dummy/vm_dummy_test.go +++ b/contract/vm_dummy/vm_dummy_test.go @@ -23,25 +23,48 @@ const min_version int32 = 2 const max_version int32 = 4 const min_version_multicall int32 = 4 +var currentVersion int32 + +func TestMain(m *testing.M) { + for version := min_version; version <= max_version; version++ { + currentVersion = version + contract.PubNet = true + fmt.Println("-------------------------------------------------------") + fmt.Printf("Running tests for hardfork %d (PubNet) \n", currentVersion) + fmt.Println("-------------------------------------------------------") + m.Run() + contract.PubNet = false + fmt.Println("-------------------------------------------------------") + fmt.Printf("Running tests for hardfork %d (PrivateNet) \n", currentVersion) + fmt.Println("-------------------------------------------------------") + m.Run() + } +} + func TestDisabledFunctions(t *testing.T) { code := readLuaCode(t, "disabled-functions.lua") - for version := int32(4); version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version), SetPubNet()) + if currentVersion < 4 { + t.Skipf("skipping test for version %d", currentVersion) + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( NewLuaTxAccount("user", 1, types.Aergo), NewLuaTxDeploy("user", "test", 0, code), ) - assert.NoErrorf(t, err, "failed to deploy contract") + require.NoErrorf(t, err, "failed to deploy contract") err = bc.ConnectBlock( NewLuaTxCall("user", "test", 0, `{"Name":"check_disabled_functions","Args":[]}`), ) - assert.NoErrorf(t, err, "failed execution") - } + require.NoErrorf(t, err, "failed execution") } func TestMaxCallDepth(t *testing.T) { @@ -51,17 +74,30 @@ func TestMaxCallDepth(t *testing.T) { // this contract stores the address of the next contract to be called code3 := readLuaCode(t, "maxcalldepth_3.lua") - for version := int32(3); version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version), SetPubNet()) + // skip if current version is less than 3 + if currentVersion < 3 { + t.Skipf("skipping test for version %d", currentVersion) + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet(), SetTimeout(1000)) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( NewLuaTxAccount("user", 1, types.Aergo), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to create account") + + var maxCallDepth int + //if currentVersion >= 5 { + maxCallDepth = 20 + //} else { + // maxCallDepth = 64 + //} + numToDeploy := maxCallDepth + 1 /* // deploy 2 identical contracts @@ -69,32 +105,22 @@ func TestMaxCallDepth(t *testing.T) { NewLuaTxDeploy("user", "c1", 0, definition1), NewLuaTxDeploy("user", "c2", 0, definition1), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to deploy contract") // call first contract - recursion depth 64 err = bc.ConnectBlock( NewLuaTxCall("user", "c1", 0, `{"Name":"call_me", "Args":[1, 64]}`), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") // check state err = bc.Query("c1", `{"Name":"check_state"}`, "", "true") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") // query view err = bc.Query("c1", `{"Name":"get_total_calls"}`, "", "[64,64]") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") for i := 1; i <= 64; i++ { err = bc.Query("c1", fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, i), "", fmt.Sprintf("%d", i)) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } // call second contract - recursion depth 66 @@ -102,165 +128,115 @@ func TestMaxCallDepth(t *testing.T) { NewLuaTxCall("user", "c2", 0, `{"Name":"call_me", "Args":[1, 66]}`). Fail("exceeded the maximum call depth"), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") // check state - should fail err = bc.Query("c2", `{"Name":"check_state"}`, "", "") - if err == nil { - t.Error("should fail") - } + assert.Errorf(t, err, "should fail") // query view - must return nil err = bc.Query("c2", `{"Name":"get_total_calls"}`, "", "[null,null]") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") for i := 1; i <= 64; i++ { err = bc.Query("c2", fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, i), "", "null") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } */ - // deploy 66 identical contracts using definition2 - for i := 1; i <= 66; i++ { + // deploy N+1 identical contracts using definition2 + for i := 1; i <= numToDeploy; i++ { err = bc.ConnectBlock( NewLuaTxDeploy("user", fmt.Sprintf("c2%d", i), 0, code2), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to deploy contract") } - // deploy 66 identical contracts using definition3 - for i := 1; i <= 66; i++ { + // deploy N+1 identical contracts using definition3 + for i := 1; i <= numToDeploy; i++ { err = bc.ConnectBlock( NewLuaTxDeploy("user", fmt.Sprintf("c3%d", i), 0, code3), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to deploy contract") } // build a list of contract IDs, used to call the first contract - contracts := make([]string, 64) + contracts := make([]string, maxCallDepth) contracts_str := []byte("") - for i := 1; i <= 64; i++ { + for i := 1; i <= maxCallDepth; i++ { contracts[i-1] = StrToAddress(fmt.Sprintf("c2%d", i)) } contracts_str, err = json.Marshal(contracts) - if err != nil { - t.Error(err) - } - // call first contract - recursion depth 64 + require.NoErrorf(t, err, "failed to create contract list") + // call first contract - recursion depth = maxCallDepth err = bc.ConnectBlock( - NewLuaTxCall("user", "c2"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, 64]}`, string(contracts_str))), + NewLuaTxCall("user", "c2"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, %d]}`, string(contracts_str), maxCallDepth)), ) - if err != nil { - t.Error(err) - } - // check state on all the 64 contracts (query total calls and call info) - for i := 1; i <= 64; i++ { + require.NoErrorf(t, err, "failed to call tx") + // check state on all the maxCallDepth contracts (query total calls and call info) + for i := 1; i <= maxCallDepth; i++ { err = bc.Query(fmt.Sprintf("c2%d", i), `{"Name":"get_total_calls"}`, "", "1") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") //err = bc.Query(fmt.Sprintf("c2%d", i), fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, i), "", fmt.Sprintf("%d", i)) err = bc.Query(fmt.Sprintf("c2%d", i), `{"Name":"get_call_info", "Args":["1"]}`, "", fmt.Sprintf("%d", i)) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } - // add the 66th contract to the list - contracts = append(contracts, StrToAddress(fmt.Sprintf("c2%d", 6))) + // add the N+1 contract to the list + contracts = append(contracts, StrToAddress(fmt.Sprintf("c2%d", numToDeploy))) contracts_str, err = json.Marshal(contracts) - if err != nil { - t.Error(err) - } - // call first contract - recursion depth 66 + require.NoErrorf(t, err, "failed to create contract list") + // call first contract - recursion depth = maxCallDepth + 1 err = bc.ConnectBlock( - NewLuaTxCall("user", "c2"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, 66]}`, string(contracts_str))).Fail("exceeded the maximum call depth"), + NewLuaTxCall("user", "c2"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, %d]}`, string(contracts_str), maxCallDepth+1)).Fail("exceeded the maximum call depth"), ) - if err != nil { - t.Error(err) - } - // check state on all the 64 contracts (query total calls and call info) - for i := 1; i <= 64; i++ { + require.NoErrorf(t, err, "failed to call tx") + // check state on all the maxCallDepth contracts (query total calls and call info) + for i := 1; i <= maxCallDepth; i++ { err = bc.Query(fmt.Sprintf("c2%d", i), `{"Name":"get_total_calls"}`, "", "1") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") err = bc.Query(fmt.Sprintf("c2%d", i), `{"Name":"get_call_info", "Args":["1"]}`, "", fmt.Sprintf("%d", i)) - if err != nil { - t.Error(err) - } - } - // check state on the 66th contract (query total calls and call info) - err = bc.Query("c2"+fmt.Sprintf("%d", 66), `{"Name":"get_total_calls"}`, "", "null") - if err != nil { - t.Error(err) - } - err = bc.Query("c2"+fmt.Sprintf("%d", 66), `{"Name":"get_call_info", "Args":["1"]}`, "", "null") - if err != nil { - t.Error(err) + assert.NoErrorf(t, err, "failed to query") } + // check state on the N+1 contract (query total calls and call info) + err = bc.Query("c2"+fmt.Sprintf("%d", maxCallDepth+1), `{"Name":"get_total_calls"}`, "", "null") + assert.NoErrorf(t, err, "failed to query") + err = bc.Query("c2"+fmt.Sprintf("%d", maxCallDepth+1), `{"Name":"get_call_info", "Args":["1"]}`, "", "null") + assert.NoErrorf(t, err, "failed to query") // set next_contract for each contract - for i := 1; i <= 66; i++ { + for i := 1; i <= numToDeploy; i++ { err = bc.ConnectBlock( NewLuaTxCall("user", fmt.Sprintf("c3%d", i), 0, fmt.Sprintf(`{"Name":"set_next_contract", "Args":["%s"]}`, StrToAddress(fmt.Sprintf("c3%d", i+1)))), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") } - // call first contract - recursion depth 64 + // call first contract - recursion depth = maxCallDepth err = bc.ConnectBlock( - NewLuaTxCall("user", "c3"+fmt.Sprintf("%d", 1), 0, `{"Name":"call_me", "Args":[1, 64]}`), + NewLuaTxCall("user", "c3"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[1, %d]}`, maxCallDepth)), ) - if err != nil { - t.Error(err) - } - // check state on all the 64 contracts (query total calls and call info) - for i := 1; i <= 64; i++ { + require.NoErrorf(t, err, "failed to call tx") + // check state on all the N contracts (query total calls and call info) + for i := 1; i <= maxCallDepth; i++ { err = bc.Query(fmt.Sprintf("c3%d", i), `{"Name":"get_total_calls"}`, "", "1") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") err = bc.Query(fmt.Sprintf("c3%d", i), `{"Name":"get_call_info", "Args":["1"]}`, "", fmt.Sprintf("%d", i)) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } - // call first contract - recursion depth 66 + // call first contract - recursion depth = maxCallDepth + 1 err = bc.ConnectBlock( - NewLuaTxCall("user", "c3"+fmt.Sprintf("%d", 1), 0, `{"Name":"call_me", "Args":[1, 66]}`).Fail("exceeded the maximum call depth"), + NewLuaTxCall("user", "c3"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[1, %d]}`, maxCallDepth+1)).Fail("exceeded the maximum call depth"), ) - if err != nil { - t.Error(err) - } - // check state on all the 64 contracts (query total calls and call info) - for i := 1; i <= 64; i++ { + require.NoErrorf(t, err, "failed to call tx") + // check state on all the N contracts (query total calls and call info) + for i := 1; i <= maxCallDepth; i++ { err = bc.Query(fmt.Sprintf("c3%d", i), `{"Name":"get_total_calls"}`, "", "1") - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") err = bc.Query(fmt.Sprintf("c3%d", i), `{"Name":"get_call_info", "Args":["1"]}`, "", fmt.Sprintf("%d", i)) - if err != nil { - t.Error(err) - } - } - // check state on the 66th contract (query total calls and call info) - err = bc.Query("c3"+fmt.Sprintf("%d", 66), `{"Name":"get_total_calls"}`, "", "null") - if err != nil { - t.Error(err) - } - err = bc.Query("c3"+fmt.Sprintf("%d", 66), `{"Name":"get_call_info", "Args":["1"]}`, "", "null") - if err != nil { - t.Error(err) + assert.NoErrorf(t, err, "failed to query") } + // check state on the N+1 contract (query total calls and call info) + err = bc.Query("c3"+fmt.Sprintf("%d", maxCallDepth+1), `{"Name":"get_total_calls"}`, "", "null") + assert.NoErrorf(t, err, "failed to query") + err = bc.Query("c3"+fmt.Sprintf("%d", maxCallDepth+1), `{"Name":"get_call_info", "Args":["1"]}`, "", "null") + assert.NoErrorf(t, err, "failed to query") // Circle: contract 1 calls contract 2, contract 2 calls contract 3, contract 3 calls contract 1... @@ -269,9 +245,7 @@ func TestMaxCallDepth(t *testing.T) { err = bc.ConnectBlock( NewLuaTxDeploy("user", fmt.Sprintf("c4%d", i), 0, code2), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to deploy contract") } // build a list of contract IDs, used to call the first contract contracts = make([]string, 4) @@ -279,50 +253,36 @@ func TestMaxCallDepth(t *testing.T) { contracts[i-1] = StrToAddress(fmt.Sprintf("c4%d", i)) } contracts_str, err = json.Marshal(contracts) - if err != nil { - t.Error(err) - } - // call first contract - recursion depth 64 + require.NoErrorf(t, err, "failed to create contract list") + // call first contract - recursion depth = maxCallDepth err = bc.ConnectBlock( - NewLuaTxCall("user", "c4"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, 64]}`, string(contracts_str))), + NewLuaTxCall("user", "c4"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, %d]}`, string(contracts_str), maxCallDepth)), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") // check state on all the 4 contracts - // each contract should have (64 / 4) = 16 calls + // each contract should have (maxCallDepth / 4) calls for i := 1; i <= 4; i++ { - err = bc.Query(fmt.Sprintf("c4%d", i), `{"Name":"get_total_calls"}`, "", "16") - if err != nil { - t.Error(err) - } - for j := 1; j <= 16; j++ { + err = bc.Query(fmt.Sprintf("c4%d", i), `{"Name":"get_total_calls"}`, "", fmt.Sprintf("%d", maxCallDepth / 4)) + assert.NoErrorf(t, err, "failed to query") + for j := 1; j <= maxCallDepth / 4; j++ { err = bc.Query(fmt.Sprintf("c4%d", i), fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, j), "", fmt.Sprintf("%d", i+4*(j-1))) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } } - // call first contract - recursion depth 66 + // call first contract - recursion depth = maxCallDepth + 1 err = bc.ConnectBlock( - NewLuaTxCall("user", "c4"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, 66]}`, string(contracts_str))).Fail("exceeded the maximum call depth"), + NewLuaTxCall("user", "c4"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, %d]}`, string(contracts_str), maxCallDepth+1)).Fail("exceeded the maximum call depth"), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") // check state on all the 4 contracts - // each contract should have (64 / 4) = 16 calls + // each contract should have (maxCallDepth / 4) calls for i := 1; i <= 4; i++ { - err = bc.Query(fmt.Sprintf("c4%d", i), `{"Name":"get_total_calls"}`, "", "16") - if err != nil { - t.Error(err) - } - for j := 1; j <= 16; j++ { + err = bc.Query(fmt.Sprintf("c4%d", i), `{"Name":"get_total_calls"}`, "", fmt.Sprintf("%d", maxCallDepth / 4)) + assert.NoErrorf(t, err, "failed to query") + for j := 1; j <= maxCallDepth / 4; j++ { err = bc.Query(fmt.Sprintf("c4%d", i), fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, j), "", fmt.Sprintf("%d", i+4*(j-1))) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } } @@ -333,9 +293,7 @@ func TestMaxCallDepth(t *testing.T) { err = bc.ConnectBlock( NewLuaTxDeploy("user", fmt.Sprintf("c5%d", i), 0, code2), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to deploy contract") } // build a list of contract IDs, used to call the first contract contracts = make([]string, 2) @@ -343,62 +301,49 @@ func TestMaxCallDepth(t *testing.T) { contracts[i-1] = StrToAddress(fmt.Sprintf("c5%d", i)) } contracts_str, err = json.Marshal(contracts) - if err != nil { - t.Error(err) - } - // call first contract - recursion depth 64 + require.NoErrorf(t, err, "failed to create contract list") + // call first contract - recursion depth = maxCallDepth err = bc.ConnectBlock( - NewLuaTxCall("user", "c5"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, 64]}`, string(contracts_str))), + NewLuaTxCall("user", "c5"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, %d]}`, string(contracts_str), maxCallDepth)), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") // check state on all the 2 contracts - // each contract should have (64 / 2) = 32 calls + // each contract should have (maxCallDepth / 2) calls for i := 1; i <= 2; i++ { - err = bc.Query(fmt.Sprintf("c5%d", i), `{"Name":"get_total_calls"}`, "", "32") - if err != nil { - t.Error(err) - } - for j := 1; j <= 32; j++ { + err = bc.Query(fmt.Sprintf("c5%d", i), `{"Name":"get_total_calls"}`, "", fmt.Sprintf("%d", maxCallDepth / 2)) + assert.NoErrorf(t, err, "failed to query") + for j := 1; j <= maxCallDepth / 2; j++ { err = bc.Query(fmt.Sprintf("c5%d", i), fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, j), "", fmt.Sprintf("%d", i+2*(j-1))) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } } - // call first contract - recursion depth 66 + // call first contract - recursion depth = maxCallDepth + 1 err = bc.ConnectBlock( - NewLuaTxCall("user", "c5"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, 66]}`, string(contracts_str))).Fail("exceeded the maximum call depth"), + NewLuaTxCall("user", "c5"+fmt.Sprintf("%d", 1), 0, fmt.Sprintf(`{"Name":"call_me", "Args":[%s, 1, %d]}`, string(contracts_str), maxCallDepth+1)).Fail("exceeded the maximum call depth"), ) - if err != nil { - t.Error(err) - } + require.NoErrorf(t, err, "failed to call tx") // check state on all the 2 contracts - // each contract should have (64 / 2) = 32 calls + // each contract should have (maxCallDepth / 2) calls for i := 1; i <= 2; i++ { - err = bc.Query(fmt.Sprintf("c5%d", i), `{"Name":"get_total_calls"}`, "", "32") - if err != nil { - t.Error(err) - } - for j := 1; j <= 32; j++ { + err = bc.Query(fmt.Sprintf("c5%d", i), `{"Name":"get_total_calls"}`, "", fmt.Sprintf("%d", maxCallDepth / 2)) + assert.NoErrorf(t, err, "failed to query") + for j := 1; j <= maxCallDepth / 2; j++ { err = bc.Query(fmt.Sprintf("c5%d", i), fmt.Sprintf(`{"Name":"get_call_info", "Args":["%d"]}`, j), "", fmt.Sprintf("%d", i+2*(j-1))) - if err != nil { - t.Error(err) - } + assert.NoErrorf(t, err, "failed to query") } } - } } func TestContractSystem(t *testing.T) { code := readLuaCode(t, "contract_system.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -415,7 +360,7 @@ func TestContractSystem(t *testing.T) { exRv := fmt.Sprintf(`["%s","6FbDRScGruVdATaNWzD51xJkTfYCVwxSZDb7gzqCLzwf","AmhNNBNY7XFk4p5ym4CJf8nTcRTEHjWzAeXJfhP71244CjBCAQU3",%d,3,999]`, StrToAddress("user1"), bc.cBlock.Header.Timestamp/1e9) assert.Equal(t, exRv, receipt.GetRet(), "receipt ret error") - if version >= 4 { + if currentVersion >= 4 { // system.version() @@ -424,7 +369,7 @@ func TestContractSystem(t *testing.T) { require.NoErrorf(t, err, "failed to call tx") receipt = bc.GetReceipt(tx.Hash()) - expected := fmt.Sprintf(`%d`, version) + expected := fmt.Sprintf(`%d`, currentVersion) assert.Equal(t, expected, receipt.GetRet(), "receipt ret error") err = bc.Query("system", `{"Name":"get_version", "Args":[]}`, "", expected) @@ -475,15 +420,16 @@ func TestContractSystem(t *testing.T) { } - } } func TestContractHello(t *testing.T) { code := readLuaCode(t, "contract_hello.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) - require.NoErrorf(t, err, "failed to create test database") + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) + require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -499,7 +445,6 @@ func TestContractHello(t *testing.T) { receipt := bc.GetReceipt(tx.Hash()) assert.Equal(t, `"Hello World"`, receipt.GetRet(), "receipt ret error") - } } func TestContractSend(t *testing.T) { @@ -508,9 +453,11 @@ func TestContractSend(t *testing.T) { code3 := readLuaCode(t, "contract_send_3.lua") code4 := readLuaCode(t, "contract_send_4.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -520,40 +467,41 @@ func TestContractSend(t *testing.T) { NewLuaTxDeploy("user1", "test3", 0, code3), NewLuaTxDeploy("user1", "test4", 0, code4), ) - assert.NoErrorf(t, err, "failed to deploy contract") + require.NoErrorf(t, err, "failed to deploy contract") err = bc.ConnectBlock( NewLuaTxCall("user1", "test1", 0, fmt.Sprintf(`{"Name":"send", "Args":["%s"]}`, nameToAddress("test2"))), ) - assert.NoErrorf(t, err, "failed to call tx") + require.NoErrorf(t, err, "failed to call tx") state, err := bc.GetAccountState("test2") assert.Equalf(t, int64(2), state.GetBalanceBigInt().Int64(), "balance error") err = bc.ConnectBlock( - NewLuaTxCall("user1", "test1", 0, fmt.Sprintf(`{"Name":"send", "Args":["%s"]}`, nameToAddress("test3"))).Fail(`[Contract.LuaSendAmount] call err: not found function: default`), + NewLuaTxCall("user1", "test1", 0, fmt.Sprintf(`{"Name":"send", "Args":["%s"]}`, nameToAddress("test3"))).Fail(`call err: not found function: default`), ) - assert.NoErrorf(t, err, "failed to connect new block") + require.NoErrorf(t, err, "failed to connect new block") err = bc.ConnectBlock( - NewLuaTxCall("user1", "test1", 0, fmt.Sprintf(`{"Name":"send", "Args":["%s"]}`, nameToAddress("test4"))).Fail(`[Contract.LuaSendAmount] call err: 'default' is not payable`), + NewLuaTxCall("user1", "test1", 0, fmt.Sprintf(`{"Name":"send", "Args":["%s"]}`, nameToAddress("test4"))).Fail(`call err: 'default' is not payable`), ) - assert.NoErrorf(t, err, "failed to connect new block") + require.NoErrorf(t, err, "failed to connect new block") err = bc.ConnectBlock( NewLuaTxCall("user1", "test1", 0, fmt.Sprintf(`{"Name":"send", "Args":["%s"]}`, nameToAddress("user1"))), ) - assert.NoErrorf(t, err, "failed to connect new block") + require.NoErrorf(t, err, "failed to connect new block") - } } func TestContractQuery(t *testing.T) { code := readLuaCode(t, "contract_query.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -575,7 +523,6 @@ func TestContractQuery(t *testing.T) { err = bc.Query("query", `{"Name":"query", "Args":["key1"]}`, "", "1") require.NoErrorf(t, err, "failed to query") - } } func TestContractCall(t *testing.T) { @@ -583,9 +530,11 @@ func TestContractCall(t *testing.T) { code2 := readLuaCode(t, "contract_call_2.lua") code3 := readLuaCode(t, "contract_call_3.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -735,15 +684,16 @@ func TestContractCall(t *testing.T) { err = bc.Query("caller", `{"Name":"get_call_info", "Args":["AmhJ2JWVSDeXxYrMRtH38hjnGDLVkLJCLD1XCTGZSjoQV2xCQUEg","get_call_info"]}`, "", expected) require.NoErrorf(t, err, "failed to query") - } } func TestContractCallSelf(t *testing.T) { code := readLuaCode(t, "contract_call_self.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -766,16 +716,17 @@ func TestContractCallSelf(t *testing.T) { receipt = bc.GetReceipt(tx.Hash()) require.Equalf(t, `5`, receipt.GetRet(), "contract call ret error") - } } func TestContractPingPongCall(t *testing.T) { code1 := readLuaCode(t, "contract_pingpongcall_1.lua") code2 := readLuaCode(t, "contract_pingpongcall_2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -800,15 +751,16 @@ func TestContractPingPongCall(t *testing.T) { err = bc.Query("B", `{"Name":"get"}`, "", `"called"`) require.NoErrorf(t, err, "failed to query") - } } func TestRollback(t *testing.T) { code := readLuaCode(t, "rollback.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -841,7 +793,6 @@ func TestRollback(t *testing.T) { err = bc.Query("query", `{"Name":"query", "Args":["key1"]}`, "", "2") require.NoErrorf(t, err, "failed to query") - } } func TestAbi(t *testing.T) { @@ -849,9 +800,11 @@ func TestAbi(t *testing.T) { codeEmpty := readLuaCode(t, "abi_empty.lua") codeLocalFunc := readLuaCode(t, "abi_localfunc.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "a", 0, codeNoAbi)) @@ -866,15 +819,16 @@ func TestAbi(t *testing.T) { require.Errorf(t, err, fmt.Sprintf("expected err : %s, buf got nil", "global function expected")) require.Containsf(t, err.Error(), "global function expected", "not contains error message") - } } func TestGetABI(t *testing.T) { code := readLuaCode(t, "getabi.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "hello", 0, code)) @@ -887,15 +841,16 @@ func TestGetABI(t *testing.T) { require.NoErrorf(t, err, "failed to marshal abi") require.Equalf(t, `{"version":"0.2","language":"lua","functions":[{"name":"hello","arguments":[{"name":"say"}]}],"state_variables":[{"name":"Say","type":"value"}]}`, string(jsonAbi), "not equal abi") - } } func TestPayable(t *testing.T) { code := readLuaCode(t, "payable.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -920,15 +875,16 @@ func TestPayable(t *testing.T) { err = bc.Query("payable", `{"Name":"load"}`, "", `"payed"`) require.NoErrorf(t, err, "failed to query") - } } func TestDefault(t *testing.T) { code := readLuaCode(t, "default.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -950,16 +906,17 @@ func TestDefault(t *testing.T) { err = bc.Query("default", `{"Name":"a"}`, "not found function: a", "") require.NoErrorf(t, err, "failed to query") - } } func TestReturn(t *testing.T) { code := readLuaCode(t, "return_1.lua") code2 := readLuaCode(t, "return_2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -981,15 +938,16 @@ func TestReturn(t *testing.T) { err = bc.Query("foo", `{"Name":"foo2", "Args":["foo314"]}`, "", `"foo314"`) require.NoErrorf(t, err, "failed to query") - } } func TestReturnUData(t *testing.T) { code := readLuaCode(t, "return_udata.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1001,15 +959,16 @@ func TestReturnUData(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "rs-return", 0, `{"Name": "test_die", "Args":[]}`).Fail(`unsupport type: userdata`)) require.NoErrorf(t, err, "failed to connect new block") - } } func TestEvent(t *testing.T) { code := readLuaCode(t, "event.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1021,16 +980,16 @@ func TestEvent(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "event", 0, `{"Name": "test_ev", "Args":[]}`)) require.NoErrorf(t, err, "failed to connect new block") - } - } func TestView(t *testing.T) { code := readLuaCode(t, "view.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1057,18 +1016,21 @@ func TestView(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "view", 0, `{"Name": "k3", "Args":[]}`)) require.NoErrorf(t, err, "failed to connect new block") - err = bc.ConnectBlock(NewLuaTxCall("user1", "view", 0, `{"Name": "sqltest", "Args":[]}`).Fail("not permitted in view function")) - require.NoErrorf(t, err, "failed to connect new block") + if bc.PubNet == false { + err = bc.ConnectBlock(NewLuaTxCall("user1", "view", 0, `{"Name": "sqltest", "Args":[]}`).Fail("not permitted in view function")) + require.NoErrorf(t, err, "failed to connect new block") + } - } } func TestDeploy(t *testing.T) { code := readLuaCode(t, "deploy.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1120,15 +1082,16 @@ func TestDeploy(t *testing.T) { receipt = bc.GetReceipt(tx.Hash()) assert.Containsf(t, receipt.GetRet(), "cannot find contract", "contract Call ret error") - } } func TestDeploy2(t *testing.T) { code := readLuaCode(t, "deploy2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() oneAergo := types.NewAmount(1, types.Aergo) @@ -1140,20 +1103,20 @@ func TestDeploy2(t *testing.T) { ) require.NoErrorf(t, err, "failed to connect new block") - tx := NewLuaTxCall("user1", "deploy", 0, `{"Name":"hello"}`).Fail(`not permitted state referencing at global scope`) + tx := NewLuaTxCall("user1", "deploy", 0, `{"Name":"hello"}`).Fail(`state referencing not permitted at global scope`) err = bc.ConnectBlock(tx) require.NoErrorf(t, err, "failed to connect new block") - } - } func TestNDeploy(t *testing.T) { code := readLuaCode(t, "deployn.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1163,15 +1126,16 @@ func TestNDeploy(t *testing.T) { ) require.NoErrorf(t, err, "failed to connect new block") - } } -func xestInfiniteLoop(t *testing.T) { +func TestInfiniteLoopOnPrivateNet(t *testing.T) { code := readLuaCode(t, "infiniteloop.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetTimeout(50), SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), SetTimeout(750)) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1180,32 +1144,40 @@ func xestInfiniteLoop(t *testing.T) { ) require.NoErrorf(t, err, "failed to connect new block") + // private nets use a limit of instruction count instead of timeout errTimeout := "exceeded the maximum instruction count" - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infiniteLoop"}`)) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infinite_loop"}`)) require.Errorf(t, err, "expected: %v", errTimeout) require.Containsf(t, err.Error(), errTimeout, "not contain timeout error") - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"catch"}`)) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"catch_loop"}`)) require.Errorf(t, err, "expected: %v", errTimeout) require.Containsf(t, err.Error(), errTimeout, "not contain timeout error") - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"contract_catch"}`)) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"contract_catch_loop"}`)) require.Errorf(t, err, "expected: %v", errTimeout) require.Containsf(t, err.Error(), errTimeout, "not contain timeout error") - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infiniteCall"}`).Fail("stack overflow")) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infinite_call"}`).Fail("stack overflow")) + require.NoErrorf(t, err, "failed to connect new block") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"catch_call"}`)) + require.NoErrorf(t, err, "failed to connect new block") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"contract_catch_call"}`)) require.NoErrorf(t, err, "failed to connect new block") - } } func TestInfiniteLoopOnPubNet(t *testing.T) { code := readLuaCode(t, "infiniteloop.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetTimeout(50), SetPubNet(), SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet(), SetTimeout(50)) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1216,30 +1188,37 @@ func TestInfiniteLoopOnPubNet(t *testing.T) { errTimeout := contract.VmTimeoutError{} - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infiniteLoop"}`)) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infinite_loop"}`)) require.Errorf(t, err, "expected: %v", errTimeout) require.Containsf(t, err.Error(), errTimeout.Error(), "not contain timeout error") - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"catch"}`)) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"catch_loop"}`)) require.Errorf(t, err, "expected: %v", errTimeout) require.Containsf(t, err.Error(), errTimeout.Error(), "not contain timeout error") - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"contract_catch"}`)) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"contract_catch_loop"}`)) require.Errorf(t, err, "expected: %v", errTimeout) require.Containsf(t, err.Error(), errTimeout.Error(), "not contain timeout error") - err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infiniteCall"}`).Fail("stack overflow")) + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"infinite_call"}`).Fail("stack overflow")) + require.NoErrorf(t, err, "failed to connect new block") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"catch_call"}`)) + require.NoErrorf(t, err, "failed to connect new block") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "loop", 0, `{"Name":"contract_catch_call"}`)) require.NoErrorf(t, err, "failed to connect new block") - } } func TestUpdateSize(t *testing.T) { code := readLuaCode(t, "updatesize.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1251,7 +1230,6 @@ func TestUpdateSize(t *testing.T) { require.Errorf(t, err, "expected: %s", errMsg) require.Containsf(t, err.Error(), errMsg, "error message not same as expected") - } } func TestTimeoutCnt(t *testing.T) { @@ -1260,9 +1238,11 @@ func TestTimeoutCnt(t *testing.T) { code := readLuaCode(t, "timeout_1.lua") code2 := readLuaCode(t, "timeout_2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetTimeout(500), SetPubNet(), SetHardForkVersion(version)) // timeout 500 milliseconds + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet(), SetTimeout(500)) // timeout 500 milliseconds require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1283,15 +1263,16 @@ func TestTimeoutCnt(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "timeout-cnt2", 0, `{"Name": "a"}`).Fail("contract timeout")) require.NoErrorf(t, err, "failed to call tx") - } } func TestSnapshot(t *testing.T) { code := readLuaCode(t, "snapshot.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1301,13 +1282,13 @@ func TestSnapshot(t *testing.T) { require.NoErrorf(t, err, "failed to deploy contract") err = bc.ConnectBlock(NewLuaTxCall("user1", "snap", 0, `{"Name": "inc", "Args":[]}`)) - assert.NoErrorf(t, err, "failed to call contract") + require.NoErrorf(t, err, "failed to call contract") err = bc.ConnectBlock(NewLuaTxCall("user1", "snap", 0, `{"Name": "inc", "Args":[]}`)) - assert.NoErrorf(t, err, "failed to call contract") + require.NoErrorf(t, err, "failed to call contract") err = bc.ConnectBlock(NewLuaTxCall("user1", "snap", 0, `{"Name": "inc", "Args":[]}`)) - assert.NoErrorf(t, err, "failed to call contract") + require.NoErrorf(t, err, "failed to call contract") err = bc.Query("snap", `{"Name":"query"}`, "", "[3,3,3,3]") assert.NoErrorf(t, err, "failed to query") @@ -1321,15 +1302,16 @@ func TestSnapshot(t *testing.T) { err = bc.Query("snap", `{"Name":"query2", "Args":[]}`, "invalid argument at getsnap, need (state.array, index, blockheight)", "") assert.NoErrorf(t, err, "failed to query") - } } func TestKvstore(t *testing.T) { code := readLuaCode(t, "kvstore.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1372,16 +1354,17 @@ func TestKvstore(t *testing.T) { err = bc.Query("map", `{"Name":"getname"}`, "", `"eve2adam"`) assert.NoErrorf(t, err, "failed to query") - } } // sql tests func TestSqlConstrains(t *testing.T) { code := readLuaCode(t, "sql_constrains.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1396,15 +1379,16 @@ func TestSqlConstrains(t *testing.T) { ) require.NoErrorf(t, err, "failed to call contract") - } } func TestSqlAutoincrement(t *testing.T) { code := readLuaCode(t, "sql_autoincrement.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1418,15 +1402,16 @@ func TestSqlAutoincrement(t *testing.T) { err = bc.ConnectBlock(tx) require.NoErrorf(t, err, "failed to call tx") - } } func TestSqlOnConflict(t *testing.T) { code := readLuaCode(t, "sql_onconflict.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1480,7 +1465,7 @@ func TestSqlOnConflict(t *testing.T) { require.NoErrorf(t, err, "failed to call tx") var expected string - if version >= 4 { + if currentVersion >= 4 { // pcall reverts the changes expected = `[1,2,3,4,5,6]` } else { @@ -1498,15 +1483,16 @@ func TestSqlOnConflict(t *testing.T) { err = bc.Query("on_conflict", `{"name":"get"}`, "", expected) require.NoErrorf(t, err, "failed to query") - } } func TestSqlDupCol(t *testing.T) { code := readLuaCode(t, "sql_dupcol.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1518,15 +1504,16 @@ func TestSqlDupCol(t *testing.T) { err = bc.Query("dup_col", `{"name":"get"}`, `too many duplicate column name "1+1", max: 5`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmSimple(t *testing.T) { code := readLuaCode(t, "sql_vm_simple.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1566,15 +1553,16 @@ func TestSqlVmSimple(t *testing.T) { err = bc.Query("simple-query", `{"Name": "count", "Args":[]}`, "not found contract", "") require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmFail(t *testing.T) { code := readLuaCode(t, "sql_vm_fail.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1606,15 +1594,16 @@ func TestSqlVmFail(t *testing.T) { err = bc.Query("fail", `{"Name":"get"}`, "", "7") require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmPubNet(t *testing.T) { code := readLuaCode(t, "sql_vm_pubnet.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetPubNet(), SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1626,15 +1615,16 @@ func TestSqlVmPubNet(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "simple-query", 0, `{"Name": "createAndInsert", "Args":[]}`).Fail(`attempt to index global 'db'`)) require.NoErrorf(t, err, "failed to call tx") - } } func TestSqlVmDateTime(t *testing.T) { code := readLuaCode(t, "sql_vm_datetime.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1644,6 +1634,9 @@ func TestSqlVmDateTime(t *testing.T) { ) require.NoErrorf(t, err, "failed to deploy") + err = bc.Query("datetime", `{"Name":"get"}`, "", `[{"bool":1,"date":"1970-01-01 02:46:40"},{"bool":0,"date":"2004-11-23"}]`) + require.NoErrorf(t, err, "failed to query") + err = bc.ConnectBlock(NewLuaTxCall("user1", "datetime", 0, `{"Name":"nowNull"}`)) require.NoErrorf(t, err, "failed to call tx") @@ -1653,15 +1646,16 @@ func TestSqlVmDateTime(t *testing.T) { err = bc.Query("datetime", `{"Name":"get"}`, "", `[{"bool":0},{"bool":1},{"bool":1,"date":"1970-01-01 02:46:40"},{"bool":0,"date":"2004-11-23"}]`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmCustomer(t *testing.T) { code := readLuaCode(t, "sql_vm_customer.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1695,15 +1689,16 @@ func TestSqlVmCustomer(t *testing.T) { err = bc.Query("customer", `{"Name":"query", "Args":["id2"]}`, "", `{}`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmDataType(t *testing.T) { code := readLuaCode(t, "sql_vm_datatype.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1729,15 +1724,16 @@ func TestSqlVmDataType(t *testing.T) { err = bc.Query("datatype", `{"Name":"queryGroupByBlockheight1"}`, "", `[{"avg_float1":3.14,"blockheight1":2,"count1":3,"sum_int1":3},{"avg_float1":3.14,"blockheight1":3,"count1":1,"sum_int1":1}]`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmFunction(t *testing.T) { code := readLuaCode(t, "sql_vm_function.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1755,15 +1751,16 @@ func TestSqlVmFunction(t *testing.T) { err = bc.Query("fns", `{"Name":"typeof_func"}`, "", `["integer","text","real","null"]`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmBook(t *testing.T) { code := readLuaCode(t, "sql_vm_book.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1782,15 +1779,16 @@ func TestSqlVmBook(t *testing.T) { err = bc.Query("book", `{"Name":"viewCopyBook"}`, "", `[100,"value=1"]`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmDateformat(t *testing.T) { code := readLuaCode(t, "sql_vm_dateformat.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1803,15 +1801,16 @@ func TestSqlVmDateformat(t *testing.T) { err = bc.Query("data_format", `{"Name":"get"}`, "", `[["2004-10-24","2004-10-24 11:11:11","20041024111111"],["2018-05-28","2018-05-28 10:45:38","20180528104538"]]`) require.NoErrorf(t, err, "failed to query") - } } func TestSqlVmRecursiveData(t *testing.T) { code := readLuaCode(t, "sql_vm_recursivedata.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() tx := NewLuaTxCall("user1", "r", 0, `{"Name":"r"}`) @@ -1823,15 +1822,16 @@ func TestSqlVmRecursiveData(t *testing.T) { require.Errorf(t, err, "expect err") require.Equalf(t, "nested table error", err.Error(), "expect err") - } } func TestSqlJdbc(t *testing.T) { code := readLuaCode(t, "sql_jdbc.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -1860,73 +1860,126 @@ func TestSqlJdbc(t *testing.T) { `{"colcnt":3,"colmetas":{"colcnt":3,"decltypes":["int","int","text"],"names":["a","b","c"]},"data":[[1,{},"2"],[2,2,"3"],[3,2,"3"],[4,2,"3"],[5,2,"3"],[6,2,"3"],[7,2,"3"]],"rowcnt":7,"snap":"3"}`) require.NoErrorf(t, err, "failed to query") - } } func TestTypeMaxString(t *testing.T) { code := readLuaCode(t, "type_maxstring.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "oom", 0, code)) require.NoErrorf(t, err, "failed to deploy") - errMsg := "not enough memory" - err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom"}`).Fail(errMsg)) - require.NoErrorf(t, err, "failed to call tx") + //errMsg1 := "string length overflow" + errMsg2 := "table overflow" + errMsg3 := "not enough memory" - err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"p"}`).Fail(errMsg)) - require.NoErrorf(t, err, "failed to call tx") + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_string"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") - err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"cp"}`).Fail(errMsg)) - require.NoErrorf(t, err, "failed to call tx") + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_string"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_string"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_table1"}`).Fail(errMsg2)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_table1"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_table1"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_table2"}`).Fail(errMsg2)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_table2"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_table2"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_global"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_global"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_global"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") - } } func TestTypeMaxStringOnPubNet(t *testing.T) { code := readLuaCode(t, "type_maxstring.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version), SetPubNet()) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "oom", 0, code)) require.NoErrorf(t, err, "failed to deploy") - errMsg := "string length overflow" - errMsg1 := "not enough memory" - var travis bool - if os.Getenv("TRAVIS") == "true" { - travis = true - } - err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom"}`)) - require.Errorf(t, err, "expected: %s", errMsg) - if !strings.Contains(err.Error(), errMsg) && !strings.Contains(err.Error(), errMsg1) { - t.Error(err) - } - err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"p"}`)) - if err != nil && (!travis || !strings.Contains(err.Error(), errMsg1)) { - t.Error(err) - } - err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"cp"}`)) - if err != nil && (!travis || !strings.Contains(err.Error(), errMsg1)) { - t.Error(err) - } + //errMsg1 := "string length overflow" + errMsg2 := "table overflow" + errMsg3 := "not enough memory" + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_string"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_string"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_string"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_table1"}`).Fail(errMsg2)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_table1"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_table1"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_table2"}`).Fail(errMsg2)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_table2"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_table2"}`)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"oom_global"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"pcall_global"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") + + err = bc.ConnectBlock(NewLuaTxCall("user1", "oom", 0, `{"Name":"contract_pcall_global"}`).Fail(errMsg3)) + assert.NoErrorf(t, err, "failed to call tx") - } } func TestTypeNsec(t *testing.T) { code := readLuaCode(t, "type_nsec.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "nsec", 0, code)) @@ -1935,15 +1988,16 @@ func TestTypeNsec(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "nsec", 0, `{"Name": "test_nsec"}`).Fail(`attempt to call global 'nsec' (a nil value)`)) require.NoErrorf(t, err, "failed to call tx") - } } func TestTypeUtf(t *testing.T) { code := readLuaCode(t, "type_utf.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "utf", 0, code)) @@ -1958,16 +2012,17 @@ func TestTypeUtf(t *testing.T) { err = bc.Query("utf", `{"Name":"query3"}`, "bignum not allowed negative value", "") assert.NoErrorf(t, err, "failed to query") - } } func TestTypeDupVar(t *testing.T) { code := readLuaCode(t, "type_dupvar_1.lua") code2 := readLuaCode(t, "type_dupvar_2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -1984,15 +2039,16 @@ func TestTypeDupVar(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "dupVar1", 0, `{"Name": "Work"}`).Fail("duplicated variable: 'Var1'")) require.NoErrorf(t, err, "failed to call tx") - } } func TestTypeByteKey(t *testing.T) { code := readLuaCode(t, "type_bytekey.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "bk", 0, code)) @@ -2004,7 +2060,6 @@ func TestTypeByteKey(t *testing.T) { err = bc.Query("bk", `{"Name":"getcre"}`, "", fmt.Sprintf(`"%s"`, nameToAddress("user1"))) require.NoErrorf(t, err, "failed to query") - } } func TestTypeArray(t *testing.T) { @@ -2012,9 +2067,11 @@ func TestTypeArray(t *testing.T) { code2 := readLuaCode(t, "type_array_overflow.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "array", 0, code)) @@ -2055,16 +2112,17 @@ func TestTypeArray(t *testing.T) { require.Errorf(t, err, "expect no error") require.Containsf(t, err.Error(), errMsg, "err not match") - } } func TestTypeMultiArray(t *testing.T) { code := readLuaCode(t, "type_multiarray_1.lua") code2 := readLuaCode(t, "type_multiarray_2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "ma", 0, code)) @@ -2107,15 +2165,16 @@ func TestTypeMultiArray(t *testing.T) { err = bc.Query("ma", `{"Name":"query", "Args":[]}`, "", `["A","B","C","D","A","B","v3"]`) require.NoErrorf(t, err, "failed to query") - } } func TestTypeArrayArg(t *testing.T) { code := readLuaCode(t, "type_arrayarg.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "a", 0, code)) @@ -2149,15 +2208,16 @@ func TestTypeArrayArg(t *testing.T) { ) require.NoErrorf(t, err, "failed to query") - } } func TestTypeMapKey(t *testing.T) { code := readLuaCode(t, "type_mapkey.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "a", 0, code)) @@ -2210,15 +2270,16 @@ func TestTypeMapKey(t *testing.T) { err = bc.Query("x", `{"Name":"getCount", "Args":["third"]}`, "", "30") require.NoErrorf(t, err, "failed to query") - } } func TestTypeStateVarFieldUpdate(t *testing.T) { code := readLuaCode(t, "type_statevarfieldupdate.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "c", 0, code)) @@ -2236,15 +2297,16 @@ func TestTypeStateVarFieldUpdate(t *testing.T) { err = bc.Query("c", `{"Name":"GetPerson"}`, "", `{"address":"blahblah...","age":10,"name":"user2"}`) require.NoErrorf(t, err, "failed to query") - } } func TestTypeDatetime(t *testing.T) { code := readLuaCode(t, "type_datetime.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "datetime", 0, code)) @@ -2352,16 +2414,17 @@ func TestTypeDatetime(t *testing.T) { err = bc.Query("datetime", `{"Name": "Difftime"}`, "", `[25500,"07:05:00"]`) require.NoErrorf(t, err, "failed to query") - } } func TestTypeDynamicArray(t *testing.T) { code := readLuaCode(t, "type_dynamicarray_zerolen.lua") code2 := readLuaCode(t, "type_dynamicarray.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo)) @@ -2408,15 +2471,16 @@ func TestTypeDynamicArray(t *testing.T) { err = bc.Query("dArr", `{"Name": "Get", "Args": [3]}`, "", "50") require.NoErrorf(t, err, "failed to query") - } } func TestTypeCrypto(t *testing.T) { code := readLuaCode(t, "type_crypto.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "crypto", 0, code)) @@ -2440,16 +2504,17 @@ func TestTypeCrypto(t *testing.T) { err = bc.Query("crypto", `{"Name": "keccak256", "Args" : ["0x616572676F"]}`, "", `"0xe98bb03ab37161f8bbfe131f711dcccf3002a9cd9ec31bbd52edf181f7ab09a0"`) require.NoErrorf(t, err, "failed to query") - } } func TestTypeBignum(t *testing.T) { bignum := readLuaCode(t, "type_bignum.lua") callee := readLuaCode(t, "type_bignum_callee.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2498,14 +2563,20 @@ func TestTypeBignum(t *testing.T) { err = bc.Query("bigNum", `{"Name":"byteBignum"}`, "", `{"_bignum":"177"}`) require.NoErrorf(t, err, "failed to query") - } } func TestBignumValues(t *testing.T) { code := readLuaCode(t, "bignum_values.lua") - bc, err := LoadDummyChain(SetHardForkVersion(2)) + if currentVersion <= 2 { + // hardfork 2 + // process octal, hex, binary + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2513,11 +2584,7 @@ func TestBignumValues(t *testing.T) { NewLuaTxDeploy("user1", "contract1", 0, code), ) require.NoErrorf(t, err, "failed to deploy") - - // hardfork 2 - - // process octal, hex, binary - + err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":["0"]}`, "", `"0"`) require.NoErrorf(t, err, "failed to query") err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":["9"]}`, "", `"9"`) @@ -2542,11 +2609,22 @@ func TestBignumValues(t *testing.T) { err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":[{"_bignum":"0b1010101010101"}]}`, "", `"5461"`) require.NoErrorf(t, err, "failed to query") + } else if currentVersion == 3 { + // hardfork 3 + // block octal, hex and binary - // hardfork 3 - bc.HardforkVersion = 3 + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) + require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } + defer bc.Release() - // block octal, hex and binary + err = bc.ConnectBlock( + NewLuaTxAccount("user1", 1, types.Aergo), + NewLuaTxDeploy("user1", "contract1", 0, code), + ) + require.NoErrorf(t, err, "failed to deploy") tx := NewLuaTxCall("user1", "contract1", 0, `{"Name":"parse_bignum", "Args":["01234567"]}`) err = bc.ConnectBlock(tx) @@ -2554,6 +2632,12 @@ func TestBignumValues(t *testing.T) { receipt := bc.GetReceipt(tx.Hash()) assert.Equalf(t, `"1234567"`, receipt.GetRet(), "contract Call ret error") + tx = NewLuaTxCall("user1", "contract1", 0, `{"Name":"parse_bignum", "Args":["01234567"]}`) + err = bc.ConnectBlock(tx) + require.NoErrorf(t, err, "failed to call tx") + receipt = bc.GetReceipt(tx.Hash()) + assert.Equalf(t, `"1234567"`, receipt.GetRet(), "contract Call ret error") + err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":["0"]}`, "", `"0"`) require.NoErrorf(t, err, "failed to query") err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":["9"]}`, "", `"9"`) @@ -2573,17 +2657,19 @@ func TestBignumValues(t *testing.T) { require.NoErrorf(t, err, "failed to query") err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":[{"_bignum":"01234567"}]}`, "", `"1234567"`) require.NoErrorf(t, err, "failed to query") - err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":[{"_bignum":"0x123456789abcdef"}]}`, "bignum invalid number string", `""`) + err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":[{"_bignum":"0x123456789abcdef"}]}`, "invalid arguments", `""`) require.NoErrorf(t, err, "failed to query") - err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":[{"_bignum":"0b1010101010101"}]}`, "bignum invalid number string", `""`) + err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":[{"_bignum":"0b1010101010101"}]}`, "invalid arguments", `""`) require.NoErrorf(t, err, "failed to query") + } else { + // hardfork 4 and after - // hardfork 4 and after - - for version := int32(4); version <= max_version; version++ { - bc, err = LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2594,10 +2680,10 @@ func TestBignumValues(t *testing.T) { // process hex, binary. block octal - tx = NewLuaTxCall("user1", "contract1", 0, `{"Name":"parse_bignum", "Args":["01234567"]}`) + tx := NewLuaTxCall("user1", "contract1", 0, `{"Name":"parse_bignum", "Args":["01234567"]}`) err = bc.ConnectBlock(tx) require.NoErrorf(t, err, "failed to call tx") - receipt = bc.GetReceipt(tx.Hash()) + receipt := bc.GetReceipt(tx.Hash()) assert.Equalf(t, `"1234567"`, receipt.GetRet(), "contract Call ret error") err = bc.Query("contract1", `{"Name":"parse_bignum", "Args":["0"]}`, "", `"0"`) @@ -2639,9 +2725,11 @@ func TestTypeRandom(t *testing.T) { code1 := readLuaCode(t, "type_random.lua") code2 := readLuaCode(t, "type_random_caller.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2691,15 +2779,16 @@ func TestTypeRandom(t *testing.T) { receipt = bc.GetReceipt(tx.Hash()) assert.Equalf(t, `false`, receipt.GetRet(), "random numbers are the same on the same transaction") - } } func TestTypeSparseTable(t *testing.T) { code := readLuaCode(t, "type_sparsetable.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() tx := NewLuaTxCall("user1", "r", 0, `{"Name":"r"}`) @@ -2709,15 +2798,16 @@ func TestTypeSparseTable(t *testing.T) { receipt := bc.GetReceipt(tx.Hash()) require.Equalf(t, `1`, receipt.GetRet(), "contract Call ret error") - } } func TestTypeJson(t *testing.T) { code := readLuaCode(t, "type_json.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "json", 0, code)) @@ -2777,16 +2867,17 @@ func TestTypeJson(t *testing.T) { err = bc.ConnectBlock(NewLuaTxCall("user1", "json", 0, `{"Name":"set", "Args":["{\"key1\":[1,2,3], \"key1\":5}}"]}`).Fail("not proper json format")) require.NoErrorf(t, err, "failed to call tx") - } } // feature tests func TestFeatureVote(t *testing.T) { code := readLuaCode(t, "feature_vote.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2841,15 +2932,16 @@ func TestFeatureVote(t *testing.T) { err = bc.Query("vote", `{"Name":"getCandidates"}`, "", `[{"count":"2","id":0,"name":"candidate1"},{"count":"0","id":1,"name":"candidate2"},{"count":"0","id":2,"name":"candidate3"}]`) require.NoErrorf(t, err, "failed to query") - } } func TestFeatureGovernance(t *testing.T) { code := readLuaCode(t, "feature_governance.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 40000, types.Aergo), NewLuaTxDeploy("user1", "gov", 0, code)) @@ -2890,7 +2982,6 @@ func TestFeatureGovernance(t *testing.T) { require.Equalf(t, oldstaking.Amount, newstaking.Amount, "pcall error, staking amount should be same") require.Equalf(t, oldgov.GetBalance(), newgov.GetBalance(), "pcall error, gov balance should be same") - } } func TestFeaturePcallRollback(t *testing.T) { @@ -2898,9 +2989,11 @@ func TestFeaturePcallRollback(t *testing.T) { code2 := readLuaCode(t, "feature_pcall_rollback_2.lua") code3 := readLuaCode(t, "feature_pcall_rollback_3.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2919,14 +3012,18 @@ func TestFeaturePcallRollback(t *testing.T) { ) require.NoErrorf(t, err, "failed to deploy") - err = bc.ConnectBlock(NewLuaTxCall("user1", "caller", 0, `{"Name":"sql", "Args":[]}`)) - require.NoErrorf(t, err, "failed to call tx") + if bc.PubNet == false { + err = bc.ConnectBlock(NewLuaTxCall("user1", "caller", 0, `{"Name":"sql", "Args":[]}`)) + require.NoErrorf(t, err, "failed to call tx") + } err = bc.Query("caller", `{"Name":"get", "Args":[]}`, "", "2") require.NoErrorf(t, err, "failed to query") - err = bc.Query("caller", `{"Name":"sqlget", "Args":[]}`, "", "2") - require.NoErrorf(t, err, "failed to query") + if bc.PubNet == false { + err = bc.Query("caller", `{"Name":"sqlget", "Args":[]}`, "", "2") + require.NoErrorf(t, err, "failed to query") + } tx := NewLuaTxCall("user1", "caller", 0, `{"Name":"getOrigin", "Args":[]}`) err = bc.ConnectBlock(tx) @@ -2937,8 +3034,11 @@ func TestFeaturePcallRollback(t *testing.T) { // create new dummy chain - bc, err = LoadDummyChain(SetHardForkVersion(version)) + bc, err = LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -2973,15 +3073,16 @@ func TestFeaturePcallRollback(t *testing.T) { require.NoErrorf(t, err, "failed to get account state") assert.Equal(t, int64(3), state.GetBalanceBigInt().Int64(), "balance error") - } } func TestFeaturePcallNested(t *testing.T) { code := readLuaCode(t, "feature_pcall_nested.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -3004,18 +3105,16 @@ func TestFeaturePcallNested(t *testing.T) { require.NoErrorf(t, err, "failed to get account state") assert.Equal(t, int64(types.Aergo), state.GetBalanceBigInt().Int64(), "balance error") - } } // test rollback of state variable and balance func TestPcallStateRollback1(t *testing.T) { resolver := readLuaCode(t, "resolver.lua") - for version := min_version; version <= max_version; version++ { files := make([]string, 0) files = append(files, "feature_pcall_rollback_4a.lua") // contract.pcall - if version >= 4 { + if currentVersion >= 4 { files = append(files, "feature_pcall_rollback_4b.lua") // pcall files = append(files, "feature_pcall_rollback_4c.lua") // xpcall } @@ -3025,8 +3124,11 @@ func TestPcallStateRollback1(t *testing.T) { code := readLuaCode(t, file) - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() // deploy and setup the name resolver @@ -3041,9 +3143,9 @@ func TestPcallStateRollback1(t *testing.T) { // deploy the contracts err = bc.ConnectBlock( - NewLuaTxDeploy("user", "A", 3, code).Constructor(fmt.Sprintf(`["%s","A"]`, nameToAddress("resolver"))), - NewLuaTxDeploy("user", "B", 0, code).Constructor(fmt.Sprintf(`["%s","B"]`, nameToAddress("resolver"))), - NewLuaTxDeploy("user", "C", 0, code).Constructor(fmt.Sprintf(`["%s","C"]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "A", 3, code).Constructor(fmt.Sprintf(`["%s","A",false]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "B", 0, code).Constructor(fmt.Sprintf(`["%s","B",false]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "C", 0, code).Constructor(fmt.Sprintf(`["%s","C",false]`, nameToAddress("resolver"))), ) require.NoErrorf(t, err, "failed to deploy the contracts") @@ -3450,7 +3552,6 @@ func TestPcallStateRollback1(t *testing.T) { map[string]int64{"A": 3, "B": 0}) } - } } // test rollback of state variable and balance - send separate from call @@ -3458,10 +3559,9 @@ func TestPcallStateRollback2(t *testing.T) { t.Skip("disabled until bug with test is fixed") resolver := readLuaCode(t, "resolver.lua") - for version := min_version; version <= max_version; version++ { files := make([]string, 0) files = append(files, "feature_pcall_rollback_4a.lua") // contract.pcall - if version >= 4 { + if currentVersion >= 4 { files = append(files, "feature_pcall_rollback_4b.lua") // pcall files = append(files, "feature_pcall_rollback_4c.lua") // xpcall } @@ -3471,8 +3571,11 @@ func TestPcallStateRollback2(t *testing.T) { code := readLuaCode(t, file) - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() // deploy and setup the name resolver @@ -3489,9 +3592,9 @@ func TestPcallStateRollback2(t *testing.T) { // deploy the contracts err = bc.ConnectBlock( - NewLuaTxDeploy("user", "A", 3, code).Constructor(fmt.Sprintf(`["%s","A"]`, nameToAddress("resolver"))), - NewLuaTxDeploy("user", "B", 0, code).Constructor(fmt.Sprintf(`["%s","B"]`, nameToAddress("resolver"))), - NewLuaTxDeploy("user", "C", 0, code).Constructor(fmt.Sprintf(`["%s","C"]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "A", 3, code).Constructor(fmt.Sprintf(`["%s","A",false]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "B", 0, code).Constructor(fmt.Sprintf(`["%s","B",false]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "C", 0, code).Constructor(fmt.Sprintf(`["%s","C",false]`, nameToAddress("resolver"))), ) require.NoErrorf(t, err, "failed to deploy the contracts") @@ -3992,7 +4095,6 @@ func TestPcallStateRollback2(t *testing.T) { map[string]int64{"A": 3, "B": 0}) } - } } // test rollback of db @@ -4000,10 +4102,9 @@ func TestPcallStateRollback3(t *testing.T) { t.Skip("disabled until bug with test is fixed") resolver := readLuaCode(t, "resolver.lua") - for version := min_version; version <= max_version; version++ { files := make([]string, 0) files = append(files, "feature_pcall_rollback_4a.lua") // contract.pcall - if version >= 4 { + if currentVersion >= 4 { files = append(files, "feature_pcall_rollback_4b.lua") // pcall files = append(files, "feature_pcall_rollback_4c.lua") // xpcall } @@ -4013,16 +4114,19 @@ func TestPcallStateRollback3(t *testing.T) { code := readLuaCode(t, file) - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPrivNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( NewLuaTxAccount("user", 1, types.Aergo), NewLuaTxDeploy("user", "resolver", 0, resolver), - NewLuaTxDeploy("user", "A", 0, code).Constructor(fmt.Sprintf(`["%s","A"]`, nameToAddress("resolver"))), - NewLuaTxDeploy("user", "B", 0, code).Constructor(fmt.Sprintf(`["%s","B"]`, nameToAddress("resolver"))), - NewLuaTxDeploy("user", "C", 0, code).Constructor(fmt.Sprintf(`["%s","C"]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "A", 0, code).Constructor(fmt.Sprintf(`["%s","A",true]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "B", 0, code).Constructor(fmt.Sprintf(`["%s","B",true]`, nameToAddress("resolver"))), + NewLuaTxDeploy("user", "C", 0, code).Constructor(fmt.Sprintf(`["%s","C",true]`, nameToAddress("resolver"))), ) require.NoErrorf(t, err, "failed to deploy") @@ -4414,7 +4518,6 @@ func TestPcallStateRollback3(t *testing.T) { map[string]int{"A": 0, "B": 0}) } - } } func testStateRollback(t *testing.T, bc *DummyChain, script string, expected_state map[string]int, expected_amount map[string]int64) { @@ -4518,9 +4621,11 @@ func testDbStateRollback(t *testing.T, bc *DummyChain, script string, expected m func TestFeatureLuaCryptoVerifyProof(t *testing.T) { code := readLuaCode(t, "feature_crypto_verify_proof.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock(NewLuaTxAccount("user1", 1, types.Aergo), NewLuaTxDeploy("user1", "eth", 0, code)) @@ -4532,16 +4637,17 @@ func TestFeatureLuaCryptoVerifyProof(t *testing.T) { err = bc.Query("eth", `{"Name":"verifyProofHex"}`, "", `true`) require.NoErrorf(t, err, "failed to query") - } } func TestFeatureFeeDelegation(t *testing.T) { code := readLuaCode(t, "feature_feedelegation_1.lua") code2 := readLuaCode(t, "feature_feedelegation_2.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetPubNet(), SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnPubNet()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -4582,7 +4688,6 @@ func TestFeatureFeeDelegation(t *testing.T) { require.Errorf(t, err, "expect error") require.Containsf(t, err.Error(), "no 'check_delegation' function", "invalid error message") - } } /* @@ -4606,11 +4711,13 @@ func TestFeatureFeeDelegationLoop(t *testing.T) { abi.payable(default) abi.fee_delegation(query_no) ` - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(OnPubNet, SetHardForkVersion(version)) + bc, err := LoadDummyChain(OnPubNet, SetHardForkVersion(currentVersion)) if err != nil { t.Errorf("failed to create test database: %v", err) } + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() balance, _ := new(big.Int).SetString("1000000000000000000000", 10) @@ -4639,7 +4746,6 @@ func TestFeatureFeeDelegationLoop(t *testing.T) { err = bc.ConnectBlock(txs...) if err != nil { t.Error(err) - } } */ @@ -4647,9 +4753,11 @@ func TestFeatureFeeDelegationLoop(t *testing.T) { func TestContractIsolation(t *testing.T) { code := readLuaCode(t, "feature_isolation.lua") - for version := min_version; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -4680,7 +4788,6 @@ func TestContractIsolation(t *testing.T) { receipt = bc.GetReceipt(tx.Hash()) require.Equalf(t, ``, receipt.GetRet(), "contract call ret error") - } } ////////////////////////////////////////////////////////////////////// @@ -4806,9 +4913,15 @@ func execute_block(t *testing.T, bc *DummyChain, txns []*luaTxCall, expectedResu func TestComposableTransactions(t *testing.T) { code := readLuaCode(t, "feature_multicall.lua") - for version := min_version_multicall; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + if currentVersion < min_version_multicall { + t.Skipf("skipping test for version %d", currentVersion) + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -6077,16 +6190,21 @@ func TestComposableTransactions(t *testing.T) { ["return","%sender%"] ]`, ``, `"AmgMPiyZYr19kQ1kHFNiGenez1CRTBqNWqppj6gGZGEP6qszDGe1"`) - } } func TestContractMulticall(t *testing.T) { code1 := readLuaCode(t, "feature_multicall_contract.lua") code2 := readLuaCode(t, "feature_multicall.lua") - for version := min_version_multicall; version <= max_version; version++ { - bc, err := LoadDummyChain(SetHardForkVersion(version)) + if currentVersion < min_version_multicall { + t.Skipf("skipping test for version %d", currentVersion) + } + + bc, err := LoadDummyChain(SetHardForkVersion(currentVersion), RunOnAllNets()) require.NoErrorf(t, err, "failed to create dummy chain") + if bc == nil { + t.Skip("skipping test") + } defer bc.Release() err = bc.ConnectBlock( @@ -6258,7 +6376,6 @@ func TestContractMulticall(t *testing.T) { state, err = bc.GetAccountState("c3") assert.Equalf(t, int64(0), state.GetBalanceBigInt().Int64(), "balance error") - } } diff --git a/contract/vm_pool.go b/contract/vm_pool.go new file mode 100644 index 000000000..1db9bbfed --- /dev/null +++ b/contract/vm_pool.go @@ -0,0 +1,356 @@ +package contract + +import ( + "sync" + "net" + "fmt" + "strconv" + "math/rand" + "os" + "os/exec" + "time" + "path/filepath" + + "github.com/aergoio/aergo/v2/internal/enc/hex" + "github.com/aergoio/aergo/v2/contract/msg" +) + +var maxInstances int +var getCh chan *VmInstance +var freeCh chan *VmInstance +var closeCh chan bool +var repopulateCh chan bool +var once sync.Once +var VmPoolStarted bool + +func StartVMPool(numInstances int) { + once.Do(func() { + maxInstances = numInstances + // create channels for getting and freeing vm instances + getCh = make(chan *VmInstance, numInstances) + freeCh = make(chan *VmInstance, numInstances) + closeCh = make(chan bool) + repopulateCh = make(chan bool) + // start a goroutine to manage the vm instances + go vmPoolRoutine() + // wait for the vm pool to be started + for !VmPoolStarted { + time.Sleep(time.Millisecond * 25) + } + }) +} + +func StopVMPool() { + // stop the vm pool + closeCh <- true + // wait for the vm pool to be stopped + for VmPoolStarted { + time.Sleep(time.Millisecond * 25) + } +} + +func vmPoolRoutine() { + + // create vm instances + spawnVmInstances(maxInstances) + + // mark the vm pool as started + VmPoolStarted = true + + // wait for instances to be released + for { + select { + case vmInstance := <-freeCh: + // close the vm instance + vmInstance.close() + // replenish the pool + repopulatePool() + case <- repopulateCh: + repopulatePool() + case <- closeCh: + // close all instances + for _, vmInstance := range pool { + vmInstance.close() + } + // close the channels + close(getCh) + close(freeCh) + close(closeCh) + close(repopulateCh) + // mark the vm pool as stopped + VmPoolStarted = false + // exit the goroutine + return + } + } + +} + +//--------------------------------------------------------------------// +// exported functions + +func GetVmInstance() *VmInstance { + vmInstance := <-getCh + // notify the goroutine + if len(repopulateCh) == 0 { + repopulateCh <- true + } + return vmInstance +} + +func FreeVmInstance(vmInstance *VmInstance) { + if vmInstance != nil { + freeCh <- vmInstance + ctrLgr.Trace().Msg("VmInstance released") + } +} + +// flush and renew all vm instances +func FlushVmInstances() { + // first retrieve all vm instances, so when releasing the first one + // the pool is empty and then it will spawn many at once + list := []*VmInstance{} + num := len(getCh) + for i := 0; i < num; i++ { + vmInstance := GetVmInstance() + list = append(list, vmInstance) + } + for _, vmInstance := range list { + FreeVmInstance(vmInstance) + } + // wait until there is some instance on the getCh + for len(getCh) == 0 { + time.Sleep(time.Millisecond * 25) + } +} + +//--------------------------------------------------------------------// +// VmInstance type + +type VmInstance struct { + id uint64 + socketName string + secretKey [32]byte + listener *net.UnixListener + conn *net.UnixConn + pid int + used bool +} + +// pool of vm instances +var pool []*VmInstance + +// repopulate the pool with new vm instances +func repopulatePool() { + + for { + // check how many instances are available on the getCh + numAvailable := len(getCh) + // if the number of available instances is less than 5, spawn new ones + numToSpawn := maxInstances - numAvailable + if numToSpawn >= 5 { + spawnVmInstances(numToSpawn) + } else { + break + } + } + +} + +// spawn a number of vm instances +func spawnVmInstances(num int) { + var num_repeats int + + for i := 0; i < num; i++ { + // get a random id + var id uint64 + outer: + for { + id = rand.Uint64() + // check if it is already used + for _, vmInstance := range pool { + if vmInstance.id == id { + continue outer + } + } + break + } + + // get a random secret key + secretKey := [32]byte{} + rand.Read(secretKey[:]) + + // get a random name for the abstract unix domain socket + socketName := fmt.Sprintf("aergo-vm-%x", id) + + // create an abstract unix domain socket listener + listener, err := net.Listen("unix", "\x00"+socketName) + if err != nil { + ctrLgr.Error().Msg("Failed to create unix domain socket listener") + // try again + num_repeats++ + if num_repeats > 10 { + os.Exit(1) + } + i-- + continue + } + unixListener, ok := listener.(*net.UnixListener) + if !ok { + ctrLgr.Error().Msg("Failed to assign listener to *net.UnixListener") + listener.Close() + // try again + num_repeats++ + if num_repeats > 10 { + os.Exit(1) + } + i-- + continue + } + + // get the directory path of the current executable + var execDir string + execPath, err := os.Executable() + if err != nil { + ctrLgr.Error().Err(err).Msg("Failed to get executable path") + } else { + execDir = filepath.Dir(execPath) + } + + // try different paths for the external VM executable + execPath = os.Getenv("AERGOVM_PATH") + if execPath == "" { + execPath = filepath.Join(execDir, "aergovm") + // check if the file exists + if _, err := os.Stat(execPath); err != nil { + execPath = "./aergovm" + if _, err := os.Stat(execPath); err != nil { + execPath = "aergovm" + } + } + } + + // spawn the exernal VM executable process + cmd := exec.Command(execPath, strconv.Itoa(int(CurrentForkVersion)), map[bool]string{true: "1", false: "0"}[PubNet], socketName, hex.Encode(secretKey[:])) + err = cmd.Start() + if err != nil { + ctrLgr.Error().Err(err).Msg("Failed to spawn external VM process") + listener.Close() + // try again + num_repeats++ + if num_repeats > 10 { + os.Exit(1) + } + i-- + continue + } + // get the process id + pid := cmd.Process.Pid + ctrLgr.Trace().Msgf("Spawned external VM process with pid: %d", pid) + + // create a vm instance object + vmInstance := &VmInstance{ + id: id, + socketName: socketName, + secretKey: secretKey, + listener: unixListener, + conn: nil, + pid: pid, + used: false, + } + + // add the vm instance to the pool + pool = append(pool, vmInstance) + + } + + // keep track of the instances that should be removed + instancesToRemove := []*VmInstance{} + + // keep track of the new instances that are connected + instancesToRead := []*VmInstance{} + + // the timeout is 100 milliseconds for each vm instance + timeout := time.Millisecond * time.Duration(100 * num) + if timeout < time.Second { + timeout = time.Second + } + // set a common deadline for the accept calls + deadline := time.Now().Add(timeout) + + // iterate over all instances + for _, vmInstance := range pool { + // if this VM instance is not yet connected + if vmInstance.conn == nil { + // set a deadline for the accept call + vmInstance.listener.SetDeadline(deadline) + // wait for the incoming connection + var err error + vmInstance.conn, err = vmInstance.listener.AcceptUnix() + if err == nil { + // connection accepted + instancesToRead = append(instancesToRead, vmInstance) + } else { + ctrLgr.Error().Msgf("Failed to accept incoming connection: %v", err) + instancesToRemove = append(instancesToRemove, vmInstance) + } + } + } + + // remove the instances that are not connected + for _, vmInstance := range instancesToRemove { + vmInstance.close() + } + + // iterate over the instances that are connected + for _, vmInstance := range instancesToRead { + // wait for a message from the vm instance + message, err := msg.WaitForMessage(vmInstance.conn, deadline) + if err != nil { + ctrLgr.Error().Msgf("Failed to read incoming message: %v", err) + vmInstance.close() + continue + } + // check if the data is valid + if !isValidMessage(vmInstance, message) { + ctrLgr.Error().Msg("Invalid message received") + vmInstance.close() + continue + } + // send the instance to the getCh + getCh <- vmInstance + } + +} + +func isValidMessage(vmInstance *VmInstance, message []byte) bool { + if string(message) == "ready" { + return true + } + return false +} + +// this should ONLY be called by the vmPoolRoutine. use FreeVmInstance() to release a vm instance +func (vmInstance *VmInstance) close() { + if vmInstance != nil { + // close the connections + if vmInstance.listener != nil { + vmInstance.listener.Close() + } + if vmInstance.conn != nil { + vmInstance.conn.Close() + } + // terminate the process + process, err := os.FindProcess(vmInstance.pid) + if err == nil { + process.Kill() + } + // remove the vm instance from the pool + for i, v := range pool { + if v == vmInstance { + pool = append(pool[:i], pool[i+1:]...) + break + } + } + } +} diff --git a/contract/vm_server.go b/contract/vm_server.go new file mode 100644 index 000000000..6648ca546 --- /dev/null +++ b/contract/vm_server.go @@ -0,0 +1,386 @@ +package contract +/* +#include "db_module.h" +*/ +import "C" +import ( + "encoding/json" + "encoding/binary" + "errors" + "strconv" + "unsafe" + "time" + "github.com/aergoio/aergo/v2/contract/msg" + "github.com/aergoio/aergo/v2/types" +) + +// convert the arguments to a single string containing the JSON array +func (ce *executor) convertArgsToJSON() (string, error) { + if ce.ci == nil || ce.ci.Args == nil { + return "[]", nil + } + args, err := json.Marshal(ce.ci.Args) + if err != nil { + return "", err + } + return string(args), nil +} + +func (ce *executor) call(extractUsedGas bool) { + + if ce.err != nil { + return + } + + /* + defer func() { + if PubNet == false { + C.db_release_resource() + // test: A -> B -> C, what happens if A and C use the same db? (same contract) + // maybe contract C should release only its own resources + // the struct could store the instance id, and release only its own resources + } + }() + */ + + if ce.isView == true { + ce.ctx.nestedView++ + defer func() { + ce.ctx.nestedView-- + }() + } + + // what to send: + // - address: types.EncodeAddress(ce.ctx.curContract.contractId) string + // - bytecode: ce.code []byte + // - function name: ce.fname string + // - args: ce.ci.Args []interface{} + // - gas: ce.contractGasLimit uint64 + // - sender: types.EncodeAddress(ce.ctx.curContract.sender) string + // - hasParent: ce.ctx.callDepth > 1 + // - isFeeDelegation: ce.ctx.isFeeDelegation bool + + address := types.EncodeAddress(ce.ctx.curContract.contractId) + bytecode := string(ce.code) + fname := ce.fname + if ce.isAutoload == true { + fname = "autoload:" + ce.fname + } + // convert the parameters to strings + args, err := ce.convertArgsToJSON() + if err != nil { + ce.err = err + return + } + //gas := strconv.FormatUint(ce.contractGasLimit, 10) + gas := string((*[8]byte)(unsafe.Pointer(&ce.contractGasLimit))[:]) + sender := types.EncodeAddress(ce.ctx.curContract.sender) + hasParent := strconv.FormatBool(ce.ctx.callDepth > 1) + isFeeDelegation := strconv.FormatBool(ce.ctx.isFeeDelegation) + abiError := "" + if ce.abiErr != nil { + abiError = ce.abiErr.Error() + } + + // build the message + message := msg.SerializeMessage("execute", address, bytecode, fname, args, gas, sender, hasParent, isFeeDelegation, abiError) + + // send the execution request to the VM instance + err = ce.SendMessage(message) + if err != nil { + ce.err = err + return + } + + // if this is the first call, wait messages in a loop + //if ce.ctx.callDepth == 1 { + // return MessageLoop() + //} + + // wait for and process messages in a loop + result, err := ce.MessageLoop() + + if extractUsedGas && len(result) >= 8 { + // extract the used gas from the result + ce.usedGas = binary.LittleEndian.Uint64([]byte(result[:8])) + result = result[8:] + } + + // return the result from the VM instance + ce.jsonRet = result + ce.err = err + + // when a message arrives, process it + // when the first VM finishes (or timeout occurs) return from this function + +} + +// only messages from the last/top contract can be processed +// (a hacker could send as if it is from another contract) +// also use encryption with diff key for each instance + +// incoming messages processed: +// 1. only from the last/top contract +// 2. only for 'vm_callback' functions, and responses + + +func (ce *executor) MessageLoop() (result string, err error) { + + // wait for messages in a loop + for { + message, err := ce.WaitForMessage() + if err != nil { + return "", err + } + // deserialize the message + args, err := msg.DeserializeMessage(message) + if err != nil { + return "", err + } + // extract the command, arguments and whether it is within a view function + if len(args) < 2 { + return "", errors.New("[MessageLoop] invalid arguments from VM") + } + command := args[0] + inView := args[len(args)-1] == "1" + args = args[1:len(args)-1] + // process the request + if inView { + ce.ctx.nestedView++ + } + result, err = ce.ProcessCommand(command, args) + if inView { + ce.ctx.nestedView-- + } + // if the VM finished, return the result + if command == "return" { + return result, err + } + // serialize the response + var errMsg string + if err != nil { + errMsg = err.Error() + } + response := msg.SerializeMessage(result, errMsg) + // send the response + err = ce.SendMessage(response) + if err != nil { + return "", err // different type of error + } + } + +} + +// sends a message to the VM instance +func (ce *executor) SendMessage(message []byte) (err error) { + return msg.SendMessage(ce.vmInstance.conn, message) +} + +// waits for a message from the VM instance +func (ce *executor) WaitForMessage() ([]byte, error) { + + if ce.ctx.callDepth == 1 && ce.ctx.deadline.IsZero() { + // define a global deadline for contract execution + ce.ctx.deadline = time.Now().Add(250 * time.Millisecond) + } + + return msg.WaitForMessage(ce.vmInstance.conn, ce.ctx.deadline) +} + +// process the command from the VM instance +func (ce *executor) ProcessCommand(command string, args []string) (result string, err error) { + + ctx := ce.ctx + + switch command { + + // return from call + + case "return": + return ce.handleReturnFromCall(args) + + // state variables + + case "set": + return ctx.handleSetVariable(args) + case "get": + return ctx.handleGetVariable(args) + case "del": + return ctx.handleDelVariable(args) + + // contract + + case "deploy": + return ctx.handleDeploy(args) + case "call": + return ctx.handleCall(args) + case "delegate-call": + return ctx.handleDelegateCall(args) + case "send": + return ctx.handleSend(args) + case "balance": + return ctx.handleGetBalance(args) + case "event": + return ctx.handleEvent(args) + + // system + + case "toPubkey": + return ctx.handleToPubkey(args) + case "toAddress": + return ctx.handleToAddress(args) + case "isContract": + return ctx.handleIsContract(args) + case "getContractId": + return ctx.handleGetContractId() + case "getAmount": + return ctx.handleGetAmount() + case "getBlockNo": + return ctx.handleGetBlockNo() + case "getTimeStamp": + return ctx.handleGetTimeStamp() + case "getPrevBlockHash": + return ctx.handleGetPrevBlockHash() + case "getTxHash": + return ctx.handleGetTxHash() + case "getOrigin": + return ctx.handleGetOrigin() + case "randomInt": + return ctx.handleRandomInt(args) + case "print": + return ctx.handlePrint(args) + + // name service + + case "nameResolve": + return ctx.handleNameResolve(args) + + // governance + + case "governance": + return ctx.handleGovernance(args) + case "getStaking": + return ctx.handleGetStaking(args) + + // crypto + + case "sha256": + return ctx.handleCryptoSha256(args) + case "keccak256": + return ctx.handleCryptoKeccak256(args) + case "ecVerify": + return ctx.handleECVerify(args) + case "verifyEthStorageProof": + return ctx.handleCryptoVerifyEthStorageProof(args) + + // db + + case "dbExec": + return ctx.handleDbExec(args) + case "dbQuery": + return ctx.handleDbQuery(args) + case "dbPrepare": + return ctx.handleDbPrepare(args) + case "stmtExec": + return ctx.handleStmtExec(args) + case "stmtQuery": + return ctx.handleStmtQuery(args) + case "stmtColumnInfo": + return ctx.handleStmtColumnInfo(args) + case "rsNext": + return ctx.handleRsNext(args) + case "rsGet": + return ctx.handleRsGet(args) + //case "rsClose": + // return ctx.handleRsClose(args) + case "lastInsertRowid": + return ctx.handleLastInsertRowid(args) + case "dbOpenWithSnapshot": + return ctx.handleDbOpenWithSnapshot(args) + case "dbGetSnapshot": + return ctx.handleDbGetSnapshot(args) + + // internal + + case "setRecoveryPoint": + return ctx.handleSetRecoveryPoint() + case "clearRecovery": + return ctx.handleClearRecovery(args) + + } + + return "", errors.New("invalid command: " + command) + +} + +// handle the return from a call +func (ce *executor) handleReturnFromCall(args []string) (result string, err error) { + + if len(args) != 2 { + return "", errors.New("[ReturnFromVM] invalid return value from contract") + } + result = args[0] // JSON + errStr := args[1] // error message + + /* + // add the used gas and check if the execution ran out of gas + err = ce.processUsedGas(result) + + if errStr != "" { + if err != nil { + err = errors.New("[ReturnFromVM] 1: " + err.Error() + ", 2: " + errStr) + } else { + err = errors.New(errStr) + } + } + */ + + if errStr != "" { + err = errors.New(errStr) + } + + return result, err +} + +/* +// add the used gas and check if the execution ran out of gas +func (ce *executor) processUsedGas(result string) (err error) { + + // check if the used gas is a valid uint64 value + if len(result) < 8 { + return errors.New("[ReturnFromVM] invalid used gas value from contract") + } + // convert the used gas to a uint64 value + usedGas := binary.LittleEndian.Uint64([]byte(result[:8])) + + // add the gas used by this contract to the total gas + ce.ctx.accumulatedGas += usedGas + + // check if the contract ran out of the transaction gas limit + if ce.ctx.accumulatedGas >= ce.ctx.gasLimit { + return errors.New("[ReturnFromVM] contract ran out of the transaction gas limit") + } + + // check if the contract ran out of the contract gas limit + if usedGas >= ce.contractGasLimit { + return errors.New("[ReturnFromVM] contract ran out of the contract gas limit") + } + + return nil +} +*/ + + + + +// sent when a VM is created: +// hardfork version + IsPublic + abstract domain socket name + secret key + +// sent when a contract is called: +// sender + IsFeeDelegation + + +// in the case of timeout: +// the VM pool should: +// - close all connections to the used VMs +// - kill all the processes linked with the current execution diff --git a/contract/vm_state.go b/contract/vm_state.go index 27e2bd62f..12cd39245 100644 --- a/contract/vm_state.go +++ b/contract/vm_state.go @@ -72,12 +72,18 @@ func newContractInfo(cs *callState, sender, contractId []byte, rp uint64, amount } } + +//////////////////////////////////////////////////////////////////////////////// +// State Recovery +//////////////////////////////////////////////////////////////////////////////// + type recoveryEntry struct { seq int amount *big.Int senderState *state.AccountState senderNonce uint64 callState *callState + eventCount int onlySend bool isDeploy bool sqlSaveName *string @@ -85,7 +91,7 @@ type recoveryEntry struct { prev *recoveryEntry } -func (re *recoveryEntry) recovery(bs *state.BlockState) error { +func (re *recoveryEntry) revertState(ctx *vmContext) error { var zero big.Int cs := re.callState @@ -107,10 +113,17 @@ func (re *recoveryEntry) recovery(bs *state.BlockState) error { re.senderState.SetNonce(re.senderNonce) } + // if the contract state is not stored, do not restore it if cs == nil { return nil } + // restore the event count + if ctx.blockInfo.ForkVersion >= 4 { + ctx.events = ctx.events[:re.eventCount] + ctx.eventCount = int32(re.eventCount) + } + // restore the contract state if re.stateRevision != -1 { err := cs.ctrState.Rollback(re.stateRevision) @@ -122,7 +135,7 @@ func (re *recoveryEntry) recovery(bs *state.BlockState) error { if err != nil { return newDbSystemError(err) } - bs.RemoveCache(cs.ctrState.GetAccountID()) + ctx.bs.RemoveCache(cs.ctrState.GetAccountID()) } } @@ -145,45 +158,72 @@ func (re *recoveryEntry) recovery(bs *state.BlockState) error { return nil } -func setRecoveryPoint(aid types.AccountID, ctx *vmContext, senderState *state.AccountState, - cs *callState, amount *big.Int, isSend, isDeploy bool) (int, error) { +func setRecoveryPoint( + aid types.AccountID, + ctx *vmContext, + senderState *state.AccountState, + cs *callState, + amount *big.Int, + onlySend, isDeploy bool, +) (int, error) { var seq int + + // get the previous recovery entry prev := ctx.lastRecoveryEntry + + // get the next sequence number if prev != nil { seq = prev.seq + 1 } else { seq = 1 } + + // get the sender nonce var nonce uint64 if senderState != nil { nonce = senderState.Nonce() } + + // create the recovery entry re := &recoveryEntry{ seq, amount, senderState, nonce, cs, - isSend, + 0, + onlySend, isDeploy, nil, -1, prev, } ctx.lastRecoveryEntry = re - if isSend { + + // if it's just aergo transfer, do not store the contract state + if onlySend { return seq, nil } + + // get the current event count + re.eventCount = len(ctx.events) + + // get the contract state snapshot re.stateRevision = cs.ctrState.Snapshot() + + // get the contract SQL db transaction tx := cs.tx if tx != nil { saveName := fmt.Sprintf("%s_%p", aid.String(), &re) err := tx.subSavepoint(saveName) if err != nil { - return seq, err + ctx.lastRecoveryEntry = prev + return -1, err } re.sqlSaveName = &saveName } + + // return the sequence number return seq, nil } diff --git a/fee/fee.go b/fee/fee.go index 8652a4c22..246d54f8a 100644 --- a/fee/fee.go +++ b/fee/fee.go @@ -48,7 +48,7 @@ func NewZeroFee() *big.Int { //---------------------------------------------------------------// // calc fee -// fee = gas price * gas +// fee = used gas * gas price func CalcFee(gasPrice *big.Int, gas uint64) *big.Int { return new(big.Int).Mul(gasPrice, new(big.Int).SetUint64(gas)) } diff --git a/libtool/src/luajit b/libtool/src/luajit index eda86b19a..d475918d6 160000 --- a/libtool/src/luajit +++ b/libtool/src/luajit @@ -1 +1 @@ -Subproject commit eda86b19a18b0f1da53bdc3c0f281124d4fe4260 +Subproject commit d475918d639181a08ce34615aea920ff213f6b91 diff --git a/rpc/swagger/swagger.yaml b/rpc/swagger/swagger.yaml index 19fa2775e..78a5553c8 100644 --- a/rpc/swagger/swagger.yaml +++ b/rpc/swagger/swagger.yaml @@ -824,8 +824,6 @@ paths: "StateTrace": 0, "VerifyBlock": 0, "NumWorkers": 16, - "NumLStateClosers": 2, - "CloseLimit": 100, }, "orphan": 0, "testmode": false, diff --git a/tests/config-sbp.toml b/tests/config-sbp.toml index 10b3304d7..f0138af6b 100644 --- a/tests/config-sbp.toml +++ b/tests/config-sbp.toml @@ -54,8 +54,6 @@ maxanchorcount = "20" verifiercount = "1" forceresetheight = "0" numworkers = "1" -numclosers = "1" -closelimit = "100" [mempool] showmetrics = false