From b2b35a6ba9ce8245030510ffe4b8558d68d309fd Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 13 Feb 2014 17:40:19 +0100 Subject: [PATCH] repackaged wrap into a standalone cwrap module --- cinterface.lua | 308 ++++++++++++++++++++++++++++++++++++++++++++++++ init.lua | 314 +------------------------------------------------ types.lua | 263 ++--------------------------------------- 3 files changed, 320 insertions(+), 565 deletions(-) create mode 100644 cinterface.lua diff --git a/cinterface.lua b/cinterface.lua new file mode 100644 index 0000000..d6c26bb --- /dev/null +++ b/cinterface.lua @@ -0,0 +1,308 @@ +local CInterface = {} + +function CInterface.new() + self = {} + self.txt = {} + self.registry = {} + setmetatable(self, {__index=CInterface}) + return self +end + +function CInterface:luaname2wrapname(name) + return string.format("wrapper_%s", name) +end + +function CInterface:print(str) + table.insert(self.txt, str) +end + +function CInterface:wrap(luaname, ...) + local txt = self.txt + local varargs = {...} + + assert(#varargs > 0 and #varargs % 2 == 0, 'must provide both the C function name and the corresponding arguments') + + -- add function to the registry + table.insert(self.registry, {name=luaname, wrapname=self:luaname2wrapname(luaname)}) + + table.insert(txt, string.format("static int %s(lua_State *L)", self:luaname2wrapname(luaname))) + table.insert(txt, "{") + table.insert(txt, "int narg = lua_gettop(L);") + + if #varargs == 2 then + local cfuncname = varargs[1] + local args = varargs[2] + + local helpargs, cargs, argcreturned = self:__writeheaders(txt, args) + self:__writechecks(txt, args) + + table.insert(txt, 'else') + table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(helpargs, ' '))) + + self:__writecall(txt, args, cfuncname, cargs, argcreturned) + else + local allcfuncname = {} + local allargs = {} + local allhelpargs = {} + local allcargs = {} + local allargcreturned = {} + + table.insert(txt, "int argset = 0;") + + for k=1,#varargs/2 do + allcfuncname[k] = varargs[(k-1)*2+1] + allargs[k] = varargs[(k-1)*2+2] + end + + local argoffset = 0 + for k=1,#varargs/2 do + allhelpargs[k], allcargs[k], allargcreturned[k] = self:__writeheaders(txt, allargs[k], argoffset) + argoffset = argoffset + #allargs[k] + end + + for k=1,#varargs/2 do + self:__writechecks(txt, allargs[k], k) + end + + table.insert(txt, 'else') + local allconcathelpargs = {} + for k=1,#varargs/2 do + table.insert(allconcathelpargs, table.concat(allhelpargs[k], ' ')) + end + table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(allconcathelpargs, ' | '))) + + for k=1,#varargs/2 do + if k == 1 then + table.insert(txt, string.format('if(argset == %d)', k)) + else + table.insert(txt, string.format('else if(argset == %d)', k)) + end + table.insert(txt, '{') + self:__writecall(txt, allargs[k], allcfuncname[k], allcargs[k], allargcreturned[k]) + table.insert(txt, '}') + end + + table.insert(txt, 'return 0;') + end + + table.insert(txt, '}') + table.insert(txt, '') +end + +function CInterface:register(name) + local txt = self.txt + table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) + for _,reg in ipairs(self.registry) do + table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) + end + table.insert(txt, '{NULL, NULL}') + table.insert(txt, '};') + table.insert(txt, '') + self.registry = {} +end + +function CInterface:clearhistory() + self.txt = {} + self.registry = {} +end + +function CInterface:tostring() + return table.concat(self.txt, '\n') +end + +function CInterface:tofile(filename) + local f = io.open(filename, 'w') + f:write(table.concat(self.txt, '\n')) + f:close() +end + +local function bit(p) + return 2 ^ (p - 1) -- 1-based indexing +end + +local function hasbit(x, p) + return x % (p + p) >= p +end + +local function beautify(txt) + local indent = 0 + for i=1,#txt do + if txt[i]:match('}') then + indent = indent - 2 + end + if indent > 0 then + txt[i] = string.rep(' ', indent) .. txt[i] + end + if txt[i]:match('{') then + indent = indent + 2 + end + end +end + +local function tableinsertcheck(tbl, stuff) + if stuff and not stuff:match('^%s*$') then + table.insert(tbl, stuff) + end +end + +function CInterface:__writeheaders(txt, args, argoffset) + local argtypes = self.argtypes + local helpargs = {} + local cargs = {} + local argcreturned + argoffset = argoffset or 0 + + for i,arg in ipairs(args) do + arg.i = i+argoffset + arg.args = args -- in case we want to do stuff depending on other args + assert(argtypes[arg.name], 'unknown type ' .. arg.name) + setmetatable(arg, {__index=argtypes[arg.name]}) + arg.__metatable = argtypes[arg.name] + tableinsertcheck(txt, arg:declare()) + local helpname = arg:helpname() + if arg.returned then + helpname = string.format('*%s*', helpname) + end + if arg.invisible and arg.default == nil then + error('Invisible arguments must have a default! How could I guess how to initialize it?') + end + if arg.default ~= nil then + if not arg.invisible then + table.insert(helpargs, string.format('[%s]', helpname)) + end + elseif not arg.creturned then + table.insert(helpargs, helpname) + end + if arg.creturned then + if argcreturned then + error('A C function can only return one argument!') + end + if arg.default ~= nil then + error('Obviously, an "argument" returned by a C function cannot have a default value') + end + if arg.returned then + error('Options "returned" and "creturned" are incompatible') + end + argcreturned = arg + else + table.insert(cargs, arg:carg()) + end + end + + return helpargs, cargs, argcreturned +end + +function CInterface:__writechecks(txt, args, argset) + local argtypes = self.argtypes + + local multiargset = argset + argset = argset or 1 + + local nopt = 0 + for i,arg in ipairs(args) do + if arg.default ~= nil and not arg.invisible then + nopt = nopt + 1 + end + end + + for variant=0,math.pow(2, nopt)-1 do + local opt = 0 + local currentargs = {} + local optargs = {} + local hasvararg = false + for i,arg in ipairs(args) do + if arg.invisible then + table.insert(optargs, arg) + elseif arg.default ~= nil then + opt = opt + 1 + if hasbit(variant, bit(opt)) then + table.insert(currentargs, arg) + else + table.insert(optargs, arg) + end + elseif not arg.creturned then + table.insert(currentargs, arg) + end + end + + for _,arg in ipairs(args) do + if arg.vararg then + if hasvararg then + error('Only one argument can be a "vararg"!') + end + hasvararg = true + end + end + + if hasvararg and not currentargs[#currentargs].vararg then + error('Only the last argument can be a "vararg"') + end + + local compop + if hasvararg then + compop = '>=' + else + compop = '==' + end + + if variant == 0 and argset == 1 then + table.insert(txt, string.format('if(narg %s %d', compop, #currentargs)) + else + table.insert(txt, string.format('else if(narg %s %d', compop, #currentargs)) + end + + for stackidx, arg in ipairs(currentargs) do + table.insert(txt, string.format("&& %s", arg:check(stackidx))) + end + table.insert(txt, ')') + table.insert(txt, '{') + + if multiargset then + table.insert(txt, string.format('argset = %d;', argset)) + end + + for stackidx, arg in ipairs(currentargs) do + tableinsertcheck(txt, arg:read(stackidx)) + end + + for _,arg in ipairs(optargs) do + tableinsertcheck(txt, arg:init()) + end + + table.insert(txt, '}') + + end +end + +function CInterface:__writecall(txt, args, cfuncname, cargs, argcreturned) + local argtypes = self.argtypes + + for _,arg in ipairs(args) do + tableinsertcheck(txt, arg:precall()) + end + + if argcreturned then + table.insert(txt, string.format('%s = %s(%s);', argtypes[argcreturned.name].creturn(argcreturned), cfuncname, table.concat(cargs, ','))) + else + table.insert(txt, string.format('%s(%s);', cfuncname, table.concat(cargs, ','))) + end + + for _,arg in ipairs(args) do + tableinsertcheck(txt, arg:postcall()) + end + + local nret = 0 + if argcreturned then + nret = nret + 1 + end + for _,arg in ipairs(args) do + if arg.returned then + nret = nret + 1 + end + end + table.insert(txt, string.format('return %d;', nret)) +end + +return CInterface + + diff --git a/init.lua b/init.lua index bd4199d..dabe086 100644 --- a/init.lua +++ b/init.lua @@ -1,311 +1,7 @@ -wrap = {} +local cwrap = {} -dofile(debug.getinfo(1).source:gsub('init%.lua$', 'types.lua'):gsub('^@', '')) - -local CInterface = {} -wrap.CInterface = CInterface - -function CInterface.new() - self = {} - self.txt = {} - self.registry = {} - self.argtypes = wrap.argtypes - setmetatable(self, {__index=CInterface}) - return self -end - -function CInterface:luaname2wrapname(name) - return string.format("wrapper_%s", name) -end - -function CInterface:print(str) - table.insert(self.txt, str) -end - -function CInterface:wrap(luaname, ...) - local txt = self.txt - local varargs = {...} - - assert(#varargs > 0 and #varargs % 2 == 0, 'must provide both the C function name and the corresponding arguments') - - -- add function to the registry - table.insert(self.registry, {name=luaname, wrapname=self:luaname2wrapname(luaname)}) - - table.insert(txt, string.format("static int %s(lua_State *L)", self:luaname2wrapname(luaname))) - table.insert(txt, "{") - table.insert(txt, "int narg = lua_gettop(L);") - - if #varargs == 2 then - local cfuncname = varargs[1] - local args = varargs[2] - - local helpargs, cargs, argcreturned = self:__writeheaders(txt, args) - self:__writechecks(txt, args) - - table.insert(txt, 'else') - table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(helpargs, ' '))) - - self:__writecall(txt, args, cfuncname, cargs, argcreturned) - else - local allcfuncname = {} - local allargs = {} - local allhelpargs = {} - local allcargs = {} - local allargcreturned = {} - - table.insert(txt, "int argset = 0;") - - for k=1,#varargs/2 do - allcfuncname[k] = varargs[(k-1)*2+1] - allargs[k] = varargs[(k-1)*2+2] - end - - local argoffset = 0 - for k=1,#varargs/2 do - allhelpargs[k], allcargs[k], allargcreturned[k] = self:__writeheaders(txt, allargs[k], argoffset) - argoffset = argoffset + #allargs[k] - end - - for k=1,#varargs/2 do - self:__writechecks(txt, allargs[k], k) - end - - table.insert(txt, 'else') - local allconcathelpargs = {} - for k=1,#varargs/2 do - table.insert(allconcathelpargs, table.concat(allhelpargs[k], ' ')) - end - table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(allconcathelpargs, ' | '))) - - for k=1,#varargs/2 do - if k == 1 then - table.insert(txt, string.format('if(argset == %d)', k)) - else - table.insert(txt, string.format('else if(argset == %d)', k)) - end - table.insert(txt, '{') - self:__writecall(txt, allargs[k], allcfuncname[k], allcargs[k], allargcreturned[k]) - table.insert(txt, '}') - end - - table.insert(txt, 'return 0;') - end - - table.insert(txt, '}') - table.insert(txt, '') -end - -function CInterface:register(name) - local txt = self.txt - table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name)) - for _,reg in ipairs(self.registry) do - table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname)) - end - table.insert(txt, '{NULL, NULL}') - table.insert(txt, '};') - table.insert(txt, '') - self.registry = {} -end - -function CInterface:clearhistory() - self.txt = {} - self.registry = {} -end - -function CInterface:tostring() - return table.concat(self.txt, '\n') -end - -function CInterface:tofile(filename) - local f = io.open(filename, 'w') - f:write(table.concat(self.txt, '\n')) - f:close() -end - -local function bit(p) - return 2 ^ (p - 1) -- 1-based indexing -end - -local function hasbit(x, p) - return x % (p + p) >= p -end - -local function beautify(txt) - local indent = 0 - for i=1,#txt do - if txt[i]:match('}') then - indent = indent - 2 - end - if indent > 0 then - txt[i] = string.rep(' ', indent) .. txt[i] - end - if txt[i]:match('{') then - indent = indent + 2 - end - end -end - -local function tableinsertcheck(tbl, stuff) - if stuff and not stuff:match('^%s*$') then - table.insert(tbl, stuff) - end -end - -function CInterface:__writeheaders(txt, args, argoffset) - local argtypes = self.argtypes - local helpargs = {} - local cargs = {} - local argcreturned - argoffset = argoffset or 0 - - for i,arg in ipairs(args) do - arg.i = i+argoffset - arg.args = args -- in case we want to do stuff depending on other args - assert(argtypes[arg.name], 'unknown type ' .. arg.name) - setmetatable(arg, {__index=argtypes[arg.name]}) - arg.__metatable = argtypes[arg.name] - tableinsertcheck(txt, arg:declare()) - local helpname = arg:helpname() - if arg.returned then - helpname = string.format('*%s*', helpname) - end - if arg.invisible and arg.default == nil then - error('Invisible arguments must have a default! How could I guess how to initialize it?') - end - if arg.default ~= nil then - if not arg.invisible then - table.insert(helpargs, string.format('[%s]', helpname)) - end - elseif not arg.creturned then - table.insert(helpargs, helpname) - end - if arg.creturned then - if argcreturned then - error('A C function can only return one argument!') - end - if arg.default ~= nil then - error('Obviously, an "argument" returned by a C function cannot have a default value') - end - if arg.returned then - error('Options "returned" and "creturned" are incompatible') - end - argcreturned = arg - else - table.insert(cargs, arg:carg()) - end - end - - return helpargs, cargs, argcreturned -end - -function CInterface:__writechecks(txt, args, argset) - local argtypes = self.argtypes - - local multiargset = argset - argset = argset or 1 - - local nopt = 0 - for i,arg in ipairs(args) do - if arg.default ~= nil and not arg.invisible then - nopt = nopt + 1 - end - end - - for variant=0,math.pow(2, nopt)-1 do - local opt = 0 - local currentargs = {} - local optargs = {} - local hasvararg = false - for i,arg in ipairs(args) do - if arg.invisible then - table.insert(optargs, arg) - elseif arg.default ~= nil then - opt = opt + 1 - if hasbit(variant, bit(opt)) then - table.insert(currentargs, arg) - else - table.insert(optargs, arg) - end - elseif not arg.creturned then - table.insert(currentargs, arg) - end - end - - for _,arg in ipairs(args) do - if arg.vararg then - if hasvararg then - error('Only one argument can be a "vararg"!') - end - hasvararg = true - end - end - - if hasvararg and not currentargs[#currentargs].vararg then - error('Only the last argument can be a "vararg"') - end - - local compop - if hasvararg then - compop = '>=' - else - compop = '==' - end - - if variant == 0 and argset == 1 then - table.insert(txt, string.format('if(narg %s %d', compop, #currentargs)) - else - table.insert(txt, string.format('else if(narg %s %d', compop, #currentargs)) - end - - for stackidx, arg in ipairs(currentargs) do - table.insert(txt, string.format("&& %s", arg:check(stackidx))) - end - table.insert(txt, ')') - table.insert(txt, '{') - - if multiargset then - table.insert(txt, string.format('argset = %d;', argset)) - end - - for stackidx, arg in ipairs(currentargs) do - tableinsertcheck(txt, arg:read(stackidx)) - end - - for _,arg in ipairs(optargs) do - tableinsertcheck(txt, arg:init()) - end - - table.insert(txt, '}') - - end -end - -function CInterface:__writecall(txt, args, cfuncname, cargs, argcreturned) - local argtypes = self.argtypes - - for _,arg in ipairs(args) do - tableinsertcheck(txt, arg:precall()) - end - - if argcreturned then - table.insert(txt, string.format('%s = %s(%s);', argtypes[argcreturned.name].creturn(argcreturned), cfuncname, table.concat(cargs, ','))) - else - table.insert(txt, string.format('%s(%s);', cfuncname, table.concat(cargs, ','))) - end - - for _,arg in ipairs(args) do - tableinsertcheck(txt, arg:postcall()) - end - - local nret = 0 - if argcreturned then - nret = nret + 1 - end - for _,arg in ipairs(args) do - if arg.returned then - nret = nret + 1 - end - end - table.insert(txt, string.format('return %d;', nret)) -end +cwrap.types = require 'cwrap.types' +cwrap.CInterface = require 'cwrap.cinterface' +cwrap.CInterface.argtypes = cwrap.types +return cwrap diff --git a/types.lua b/types.lua index 059ec12..bc9a900 100644 --- a/types.lua +++ b/types.lua @@ -1,255 +1,4 @@ -wrap.argtypes = {} - -wrap.argtypes.Tensor = { - - helpname = function(arg) - if arg.dim then - return string.format("Tensor~%dD", arg.dim) - else - return "Tensor" - end - end, - - declare = function(arg) - local txt = {} - table.insert(txt, string.format("THTensor *arg%d = NULL;", arg.i)) - if arg.returned then - table.insert(txt, string.format("int arg%d_idx = 0;", arg.i)); - end - return table.concat(txt, '\n') - end, - - check = function(arg, idx) - if arg.dim then - return string.format("(arg%d = luaT_toudata(L, %d, torch_Tensor)) && (arg%d->nDimension == %d)", arg.i, idx, arg.i, arg.dim) - else - return string.format("(arg%d = luaT_toudata(L, %d, torch_Tensor))", arg.i, idx) - end - end, - - read = function(arg, idx) - if arg.returned then - return string.format("arg%d_idx = %d;", arg.i, idx) - end - end, - - init = function(arg) - if type(arg.default) == 'boolean' then - return string.format('arg%d = THTensor_(new)();', arg.i) - elseif type(arg.default) == 'number' then - return string.format('arg%d = %s;', arg.i, arg.args[arg.default]:carg()) - else - error('unknown default tensor type value') - end - end, - - carg = function(arg) - return string.format('arg%d', arg.i) - end, - - creturn = function(arg) - return string.format('arg%d', arg.i) - end, - - precall = function(arg) - local txt = {} - if arg.default and arg.returned then - table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg - table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) - table.insert(txt, string.format('else')) - if type(arg.default) == 'boolean' then -- boolean: we did a new() - table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_Tensor);', arg.i)) - else -- otherwise: point on default tensor --> retain - table.insert(txt, string.format('{')) - table.insert(txt, string.format('THTensor_(retain)(arg%d);', arg.i)) -- so we need a retain - table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_Tensor);', arg.i)) - table.insert(txt, string.format('}')) - end - elseif arg.default then - -- we would have to deallocate the beast later if we did a new - -- unlikely anyways, so i do not support it for now - if type(arg.default) == 'boolean' then - error('a tensor cannot be optional if not returned') - end - elseif arg.returned then - table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) - end - return table.concat(txt, '\n') - end, - - postcall = function(arg) - local txt = {} - if arg.creturned then - -- this next line is actually debatable - table.insert(txt, string.format('THTensor_(retain)(arg%d);', arg.i)) - table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_Tensor);', arg.i)) - end - return table.concat(txt, '\n') - end -} - -wrap.argtypes.IndexTensor = { - - helpname = function(arg) - return "LongTensor" - end, - - declare = function(arg) - local txt = {} - table.insert(txt, string.format("THLongTensor *arg%d = NULL;", arg.i)) - if arg.returned then - table.insert(txt, string.format("int arg%d_idx = 0;", arg.i)); - end - return table.concat(txt, '\n') - end, - - check = function(arg, idx) - return string.format('(arg%d = luaT_toudata(L, %d, "torch.LongTensor"))', arg.i, idx) - end, - - read = function(arg, idx) - local txt = {} - if not arg.noreadadd then - table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, -1);", arg.i, arg.i)); - end - if arg.returned then - table.insert(txt, string.format("arg%d_idx = %d;", arg.i, idx)) - end - return table.concat(txt, '\n') - end, - - init = function(arg) - return string.format('arg%d = THLongTensor_new();', arg.i) - end, - - carg = function(arg) - return string.format('arg%d', arg.i) - end, - - creturn = function(arg) - return string.format('arg%d', arg.i) - end, - - precall = function(arg) - local txt = {} - if arg.default and arg.returned then - table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg - table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) - table.insert(txt, string.format('else')) -- means we did a new() - table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.LongTensor");', arg.i)) - elseif arg.default then - error('a tensor cannot be optional if not returned') - elseif arg.returned then - table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) - end - return table.concat(txt, '\n') - end, - - postcall = function(arg) - local txt = {} - if arg.creturned or arg.returned then - table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, 1);", arg.i, arg.i)); - end - if arg.creturned then - -- this next line is actually debatable - table.insert(txt, string.format('THLongTensor_retain(arg%d);', arg.i)) - table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.LongTensor");', arg.i)) - end - return table.concat(txt, '\n') - end -} - -for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor", - "FloatTensor", "DoubleTensor"}) do - - wrap.argtypes[typename] = { - - helpname = function(arg) - if arg.dim then - return string.format('%s~%dD', typename, arg.dim) - else - return typename - end - end, - - declare = function(arg) - local txt = {} - table.insert(txt, string.format("TH%s *arg%d = NULL;", typename, arg.i)) - if arg.returned then - table.insert(txt, string.format("int arg%d_idx = 0;", arg.i)); - end - return table.concat(txt, '\n') - end, - - check = function(arg, idx) - if arg.dim then - return string.format('(arg%d = luaT_toudata(L, %d, "torch.%s")) && (arg%d->nDimension == %d)', arg.i, idx, typename, arg.i, arg.dim) - else - return string.format('(arg%d = luaT_toudata(L, %d, "torch.%s"))', arg.i, idx, typename) - end - end, - - read = function(arg, idx) - if arg.returned then - return string.format("arg%d_idx = %d;", arg.i, idx) - end - end, - - init = function(arg) - if type(arg.default) == 'boolean' then - return string.format('arg%d = TH%s_new();', arg.i, typename) - elseif type(arg.default) == 'number' then - return string.format('arg%d = %s;', arg.i, arg.args[arg.default]:carg()) - else - error('unknown default tensor type value') - end - end, - - carg = function(arg) - return string.format('arg%d', arg.i) - end, - - creturn = function(arg) - return string.format('arg%d', arg.i) - end, - - precall = function(arg) - local txt = {} - if arg.default and arg.returned then - table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg - table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) - table.insert(txt, string.format('else')) - if type(arg.default) == 'boolean' then -- boolean: we did a new() - table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.%s");', arg.i, typename)) - else -- otherwise: point on default tensor --> retain - table.insert(txt, string.format('{')) - table.insert(txt, string.format('TH%s_retain(arg%d);', typename, arg.i)) -- so we need a retain - table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.%s");', arg.i, typename)) - table.insert(txt, string.format('}')) - end - elseif arg.default then - -- we would have to deallocate the beast later if we did a new - -- unlikely anyways, so i do not support it for now - if type(arg.default) == 'boolean' then - error('a tensor cannot be optional if not returned') - end - elseif arg.returned then - table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i)) - end - return table.concat(txt, '\n') - end, - - postcall = function(arg) - local txt = {} - if arg.creturned then - -- this next line is actually debatable - table.insert(txt, string.format('TH%s_retain(arg%d);', typename, arg.i)) - table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.%s");', arg.i, typename)) - end - return table.concat(txt, '\n') - end - } -end +local argtypes = {} local function interpretdefaultvalue(arg) local default = arg.default @@ -274,7 +23,7 @@ local function interpretdefaultvalue(arg) end end -wrap.argtypes.index = { +argtypes.index = { helpname = function(arg) return "index" @@ -326,7 +75,7 @@ wrap.argtypes.index = { } for _,typename in ipairs({"real", "unsigned char", "char", "short", "int", "long", "float", "double"}) do - wrap.argtypes[typename] = { + argtypes[typename] = { helpname = function(arg) return typename @@ -378,9 +127,9 @@ for _,typename in ipairs({"real", "unsigned char", "char", "short", "int", "long } end -wrap.argtypes.byte = wrap.argtypes['unsigned char'] +argtypes.byte = argtypes['unsigned char'] -wrap.argtypes.boolean = { +argtypes.boolean = { helpname = function(arg) return "boolean" @@ -430,3 +179,5 @@ wrap.argtypes.boolean = { end end } + +return argtypes