diff --git a/torchcwrap.lua b/torchcwrap.lua index 9f67c727..ab0df43e 100644 --- a/torchcwrap.lua +++ b/torchcwrap.lua @@ -19,7 +19,7 @@ types.Tensor = { end return table.concat(txt, '\n') end, - + check = function(arg, idx) if arg.dim then return string.format("(arg%d = luaT_toudata(L, %d, torch_Tensor)) && (arg%d->nDimension == %d)", arg.i, idx, arg.i, arg.dim) @@ -43,7 +43,7 @@ types.Tensor = { error('unknown default tensor type value') end end, - + carg = function(arg) return string.format('arg%d', arg.i) end, @@ -51,7 +51,7 @@ types.Tensor = { creturn = function(arg) return string.format('arg%d', arg.i) end, - + precall = function(arg) local txt = {} if arg.default and arg.returned then @@ -144,7 +144,7 @@ types.IndexTensor = { end return table.concat(txt, '\n') end, - + check = function(arg, idx) return string.format('(arg%d = luaT_toudata(L, %d, "torch.LongTensor"))', arg.i, idx) end, @@ -163,7 +163,7 @@ types.IndexTensor = { init = function(arg) return string.format('arg%d = THLongTensor_new();', arg.i) end, - + carg = function(arg) return string.format('arg%d', arg.i) end, @@ -171,7 +171,7 @@ types.IndexTensor = { creturn = function(arg) return string.format('arg%d', arg.i) end, - + precall = function(arg) local txt = {} if arg.default and arg.returned then @@ -213,7 +213,7 @@ for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor" return typename end end, - + declare = function(arg) local txt = {} table.insert(txt, string.format("TH%s *arg%d = NULL;", typename, arg.i)) @@ -222,7 +222,7 @@ for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor" end return table.concat(txt, '\n') end, - + check = function(arg, idx) if arg.dim then return string.format('(arg%d = luaT_toudata(L, %d, "torch.%s")) && (arg%d->nDimension == %d)', arg.i, idx, typename, arg.i, arg.dim) @@ -236,7 +236,7 @@ for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor" return string.format("arg%d_idx = %d;", arg.i, idx) end end, - + init = function(arg) if type(arg.default) == 'boolean' then return string.format('arg%d = TH%s_new();', arg.i, typename) @@ -254,7 +254,7 @@ for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor" creturn = function(arg) return string.format('arg%d', arg.i) end, - + precall = function(arg) local txt = {} if arg.default and arg.returned then @@ -316,6 +316,7 @@ for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor" table.insert(txt, string.format('do')) table.insert(txt, string.format('{')) table.insert(txt, string.format(' arg%d_size++;', arg.i)) + table.insert(txt, string.format(' lua_checkstack(L, 1);')) table.insert(txt, string.format(' lua_rawgeti(L, %d, arg%d_size);', idx, arg.i)) table.insert(txt, string.format('}')) table.insert(txt, string.format('while (!lua_isnil(L, -1));')) @@ -371,7 +372,7 @@ types.LongArg = { error('LongArg cannot have a default value') end end, - + check = function(arg, idx) return string.format("torch_islongargs(L, %d)", idx) end, @@ -379,7 +380,7 @@ types.LongArg = { read = function(arg, idx) return string.format("arg%d = torch_checklongargs(L, %d);", arg.i, idx) end, - + carg = function(arg, idx) return string.format('arg%d', arg.i) end, @@ -387,7 +388,7 @@ types.LongArg = { creturn = function(arg, idx) return string.format('arg%d', arg.i) end, - + precall = function(arg) local txt = {} if arg.returned then @@ -407,11 +408,11 @@ types.LongArg = { table.insert(txt, string.format('THLongStorage_free(arg%d);', arg.i)) end return table.concat(txt, '\n') - end + end } types.charoption = { - + helpname = function(arg) if arg.values then return "(" .. table.concat(arg.values, '|') .. ")" @@ -430,7 +431,7 @@ types.charoption = { init = function(arg) return string.format("arg%d = &arg%d_default;", arg.i, arg.i) end, - + check = function(arg, idx) local txt = {} local txtv = {} @@ -439,23 +440,23 @@ types.charoption = { table.insert(txtv, string.format("*arg%d == '%s'", arg.i, value)) end table.insert(txt, table.concat(txtv, ' || ')) - table.insert(txt, ')') + table.insert(txt, ')') return table.concat(txt, '') end, read = function(arg, idx) end, - + carg = function(arg, idx) return string.format('arg%d', arg.i) end, creturn = function(arg, idx) end, - + precall = function(arg) end, postcall = function(arg) - end + end }