diff --git a/init.c b/init.c index ad2b257d..08eedba7 100644 --- a/init.c +++ b/init.c @@ -35,29 +35,17 @@ extern void torch_DoubleTensorOperator_init(lua_State *L); extern void torch_TensorMath_init(lua_State *L); -static void luaTorchErrorHandlerFunction(const char *msg, void *data) -{ - lua_State *L = data; - luaL_error(L, msg); -} - -static void luaTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data) -{ - lua_State *L = data; - luaL_argcheck(L, 0, argNumber, msg); -} - LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L); int luaopen_libtorch(lua_State *L) { - THSetErrorHandler(luaTorchErrorHandlerFunction, L); - THSetArgErrorHandler(luaTorchArgErrorHandlerFunction, L); lua_newtable(L); lua_pushvalue(L, -1); lua_setglobal(L, "torch"); + torch_utils_init(L); + torch_File_init(L); torch_ByteStorage_init(L); @@ -91,7 +79,6 @@ int luaopen_libtorch(lua_State *L) torch_TensorMath_init(L); - torch_utils_init(L); torch_random_init(L); // Create 'torch.Allocator' type. diff --git a/init.lua b/init.lua index 5f250282..29d54e1e 100644 --- a/init.lua +++ b/init.lua @@ -13,6 +13,16 @@ end require "paths" paths.require "libtorch" +-- Keep track of all thread local variables torch. +-- if a Lua VM is passed to another thread thread local +-- variables need to be updated. +function torch.updatethreadlocals() + torch.updateerrorhandlers() + local tracking = torch._heaptracking + if tracking == nil then tracking = false end + torch.setheaptracking(tracking) +end + --- package stuff function torch.packageLuaPath(name) if not name then diff --git a/test/test.lua b/test/test.lua index a7a8ba46..aa309b51 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2773,6 +2773,21 @@ function torchtest.nonzero() end +function torchtest.testheaptracking() + local oldheaptracking = torch._heaptracking + if oldheaptracking == nil then + oldheaptracking = false + end + torch.setheaptracking(true) + mytester:assert(torch._heaptracking == true, 'Heap tracking expected true') + + torch.setheaptracking(false) + mytester:assert(torch._heaptracking == false, 'Heap tracking expected false') + + -- put heap tracking to its original state + torch.setheaptracking(oldheaptracking) +end + function torch.test(tests) torch.setheaptracking(true) math.randomseed(os.time()) diff --git a/utils.c b/utils.c index 0a180209..35fdae40 100644 --- a/utils.c +++ b/utils.c @@ -197,6 +197,9 @@ static void luaTorchGCFunction(void *data) static int torch_setheaptracking(lua_State *L) { int enabled = luaT_checkboolean(L,1); + lua_getglobal(L, "torch"); + lua_pushboolean(L, enabled); + lua_setfield(L, -2, "_heaptracking"); if(enabled) { THSetGCHandler(luaTorchGCFunction, L); } else { @@ -205,6 +208,25 @@ static int torch_setheaptracking(lua_State *L) return 0; } +static void luaTorchErrorHandlerFunction(const char *msg, void *data) +{ + lua_State *L = data; + luaL_error(L, msg); +} + +static void luaTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data) +{ + lua_State *L = data; + luaL_argcheck(L, 0, argNumber, msg); +} + +static int torch_updateerrorhandlers(lua_State *L) +{ + THSetErrorHandler(luaTorchErrorHandlerFunction, L); + THSetArgErrorHandler(luaTorchArgErrorHandlerFunction, L); + return 0; +} + static const struct luaL_Reg torch_utils__ [] = { {"getdefaulttensortype", torch_lua_getdefaulttensortype}, {"isatty", torch_isatty}, @@ -227,10 +249,12 @@ static const struct luaL_Reg torch_utils__ [] = { {"version", luaT_lua_version}, {"pointer", luaT_lua_pointer}, {"setheaptracking", torch_setheaptracking}, + {"updateerrorhandlers", torch_updateerrorhandlers}, {NULL, NULL} }; void torch_utils_init(lua_State *L) { + torch_updateerrorhandlers(L); luaT_setfuncs(L, torch_utils__, 0); }