Skip to content

Commit

Permalink
add lua function to update thread locals
Browse files Browse the repository at this point in the history
  • Loading branch information
koraykv committed Sep 26, 2015
1 parent cedf29e commit d55f749
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 15 deletions.
17 changes: 2 additions & 15 deletions init.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,17 @@ extern void torch_DoubleTensorOperator_init(lua_State *L);

extern void torch_TensorMath_init(lua_State *L);

static void luaTorchErrorHandlerFunction(const char *msg, void *data)
{
lua_State *L = data;
luaL_error(L, msg);
}

static void luaTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
{
lua_State *L = data;
luaL_argcheck(L, 0, argNumber, msg);
}

LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L);

int luaopen_libtorch(lua_State *L)
{
THSetErrorHandler(luaTorchErrorHandlerFunction, L);
THSetArgErrorHandler(luaTorchArgErrorHandlerFunction, L);

lua_newtable(L);
lua_pushvalue(L, -1);
lua_setglobal(L, "torch");

torch_utils_init(L);

torch_File_init(L);

torch_ByteStorage_init(L);
Expand Down Expand Up @@ -91,7 +79,6 @@ int luaopen_libtorch(lua_State *L)

torch_TensorMath_init(L);

torch_utils_init(L);
torch_random_init(L);

// Create 'torch.Allocator' type.
Expand Down
10 changes: 10 additions & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ end
require "paths"
paths.require "libtorch"

-- Keep track of all thread local variables torch.
-- if a Lua VM is passed to another thread thread local
-- variables need to be updated.
function torch.updatethreadlocals()
torch.updateerrorhandlers()
local tracking = torch._heaptracking
if tracking == nil then tracking = false end
torch.setheaptracking(tracking)
end

--- package stuff
function torch.packageLuaPath(name)
if not name then
Expand Down
15 changes: 15 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2773,6 +2773,21 @@ function torchtest.nonzero()

end

function torchtest.testheaptracking()
local oldheaptracking = torch._heaptracking
if oldheaptracking == nil then
oldheaptracking = false
end
torch.setheaptracking(true)
mytester:assert(torch._heaptracking == true, 'Heap tracking expected true')

torch.setheaptracking(false)
mytester:assert(torch._heaptracking == false, 'Heap tracking expected false')

-- put heap tracking to its original state
torch.setheaptracking(oldheaptracking)
end

function torch.test(tests)
torch.setheaptracking(true)
math.randomseed(os.time())
Expand Down
24 changes: 24 additions & 0 deletions utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ static void luaTorchGCFunction(void *data)
static int torch_setheaptracking(lua_State *L)
{
int enabled = luaT_checkboolean(L,1);
lua_getglobal(L, "torch");
lua_pushboolean(L, enabled);
lua_setfield(L, -2, "_heaptracking");
if(enabled) {
THSetGCHandler(luaTorchGCFunction, L);
} else {
Expand All @@ -205,6 +208,25 @@ static int torch_setheaptracking(lua_State *L)
return 0;
}

static void luaTorchErrorHandlerFunction(const char *msg, void *data)
{
lua_State *L = data;
luaL_error(L, msg);
}

static void luaTorchArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
{
lua_State *L = data;
luaL_argcheck(L, 0, argNumber, msg);
}

static int torch_updateerrorhandlers(lua_State *L)
{
THSetErrorHandler(luaTorchErrorHandlerFunction, L);
THSetArgErrorHandler(luaTorchArgErrorHandlerFunction, L);
return 0;
}

static const struct luaL_Reg torch_utils__ [] = {
{"getdefaulttensortype", torch_lua_getdefaulttensortype},
{"isatty", torch_isatty},
Expand All @@ -227,10 +249,12 @@ static const struct luaL_Reg torch_utils__ [] = {
{"version", luaT_lua_version},
{"pointer", luaT_lua_pointer},
{"setheaptracking", torch_setheaptracking},
{"updateerrorhandlers", torch_updateerrorhandlers},
{NULL, NULL}
};

void torch_utils_init(lua_State *L)
{
torch_updateerrorhandlers(L);
luaT_setfuncs(L, torch_utils__, 0);
}

0 comments on commit d55f749

Please sign in to comment.