diff --git a/CMakeLists.txt b/CMakeLists.txt index fb2de095..611258b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,8 +25,6 @@ IF(MSVC) ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) ENDIF(MSVC) -ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1) - # OpenMP support? SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?") IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) diff --git a/FFI.lua b/FFI.lua index ebdcd6d9..333e7b54 100644 --- a/FFI.lua +++ b/FFI.lua @@ -118,7 +118,7 @@ typedef struct THRealTensor end) -- faster apply (contiguous case) - if Tensor_type ~= 'torch.HalfTensor' or torch.hashalfmath() then + if Tensor_type ~= 'torch.HalfTensor' then local apply = Tensor.apply rawset(Tensor, "apply", diff --git a/Storage.c b/Storage.c index 28c4e87c..874838e8 100644 --- a/Storage.c +++ b/Storage.c @@ -7,3 +7,6 @@ #include "generic/Storage.c" #include "THGenerateAllTypes.h" + +#include "generic/Storage.c" +#include "THGenerateHalfType.h" diff --git a/Tensor.c b/Tensor.c index 4bfbc6ad..bf78d1aa 100644 --- a/Tensor.c +++ b/Tensor.c @@ -7,3 +7,6 @@ #include "generic/Tensor.c" #include "THGenerateAllTypes.h" + +#include "generic/Tensor.c" +#include "THGenerateHalfType.h" diff --git a/Tensor.lua b/Tensor.lua index 36307bd3..9a8215be 100644 --- a/Tensor.lua +++ b/Tensor.lua @@ -560,7 +560,7 @@ torch.permute = Tensor.permute for _,type in ipairs(types) do local metatable = torch.getmetatable('torch.' .. type .. 'Tensor') for funcname, func in pairs(Tensor) do - if funcname ~= 'totable' or type ~='Half' or torch.hashalfmath() then + if funcname ~= 'totable' or type ~='Half' then rawset(metatable, funcname, func) else local function Tensor__totable(self) diff --git a/TensorMath.lua b/TensorMath.lua index 546668b9..5b837e9f 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -1043,7 +1043,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b) {name="long",default=100}, {name="double",default=0}, {name="double",default=0}}) - + wrap("bhistc", cname("bhistc"), {{name=Tensor, default=true, returned=true}, @@ -1446,9 +1446,6 @@ void torch_TensorMath_init(lua_State *L) torch_IntTensorMath_init(L); torch_LongTensorMath_init(L); torch_FloatTensorMath_init(L); - #if TH_NATIVE_HALF - torch_HalfTensorMath_init(L); - #endif torch_DoubleTensorMath_init(L); luaT_setfuncs(L, torch_TensorMath__, 0); } diff --git a/generic/Storage.c b/generic/Storage.c index 0c8d727e..a6652a5d 100644 --- a/generic/Storage.c +++ b/generic/Storage.c @@ -131,10 +131,8 @@ static int torch_Storage_(copy)(lua_State *L) THStorage_(copyFloat)(storage, src); else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) THStorage_(copyDouble)(storage, src); -#if TH_GENERIC_USE_HALF else if( (src = luaT_toudata(L, 2, "torch.HalfStorage")) ) THStorage_(copyHalf)(storage, src); -#endif else luaL_typerror(L, 2, "torch.*Storage"); lua_settop(L, 1); diff --git a/generic/Tensor.c b/generic/Tensor.c index 1d52a971..c2417fed 100644 --- a/generic/Tensor.c +++ b/generic/Tensor.c @@ -384,7 +384,7 @@ static int torch_Tensor_(select)(lua_State *L) return 1; } -#ifndef TH_GENERIC_NO_MATH +#ifndef TH_REAL_IS_HALF static int torch_Tensor_(indexSelect)(lua_State *L) { int narg = lua_gettop(L); @@ -677,10 +677,8 @@ static int torch_Tensor_(copy)(lua_State *L) THTensor_(copyFloat)(tensor, src); else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) THTensor_(copyDouble)(tensor, src); -#if TH_GENERIC_USE_HALF else if( (src = luaT_toudata(L, 2, "torch.HalfTensor")) ) THTensor_(copyHalf)(tensor, src); -#endif else luaL_typerror(L, 2, "torch.*Tensor"); lua_settop(L, 1); @@ -755,13 +753,11 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyDouble)(tensor, src); THTensor_(free)(tensor); -#if TH_GENERIC_USE_HALF } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyHalf)(tensor, src); THTensor_(free)(tensor); -#endif } else { luaL_typerror(L, 3, "torch.*Tensor"); } @@ -867,10 +863,8 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor_(copyFloat)(tensor, src); } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) { THTensor_(copyDouble)(tensor, src); -#if TH_GENERIC_USE_HALF } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) { THTensor_(copyHalf)(tensor, src); -#endif } else { luaL_typerror(L, 3, "torch.*Tensor"); } @@ -1165,7 +1159,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index THArgCheck(0, index, "expecting number"); } -#ifndef TH_GENERIC_NO_MATH +#ifndef TH_REAL_IS_HALF static int torch_Tensor_(apply)(lua_State *L) { THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); @@ -1322,7 +1316,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = { {"narrow", torch_Tensor_(narrow)}, {"sub", torch_Tensor_(sub)}, {"select", torch_Tensor_(select)}, -#ifndef TH_GENERIC_NO_MATH +#ifndef TH_REAL_IS_HALF {"index", torch_Tensor_(indexSelect)}, {"indexCopy", torch_Tensor_(indexCopy)}, {"indexAdd", torch_Tensor_(indexAdd)}, @@ -1340,7 +1334,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = { {"isSize", torch_Tensor_(isSize)}, {"nElement", torch_Tensor_(nElement)}, {"copy", torch_Tensor_(copy)}, -#ifndef TH_GENERIC_NO_MATH +#ifndef TH_REAL_IS_HALF {"apply", torch_Tensor_(apply)}, {"map", torch_Tensor_(map)}, {"map2", torch_Tensor_(map2)}, @@ -1358,7 +1352,7 @@ void torch_Tensor_(init)(lua_State *L) torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); luaT_setfuncs(L, torch_Tensor_(_), 0); lua_pop(L, 1); -#ifndef TH_GENERIC_NO_MATH +#ifndef TH_REAL_IS_HALF THVector_(vectorDispatchInit)(); #endif diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c index c722e889..e131c573 100644 --- a/generic/TensorOperator.c +++ b/generic/TensorOperator.c @@ -2,9 +2,6 @@ #define TH_GENERIC_FILE "generic/TensorOperator.c" #else -/* Tensor math may be disabled for certain types, e.g. 'half' */ -#ifndef TH_GENERIC_NO_MATH - static int torch_TensorOperator_(__add__)(lua_State *L) { THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor); @@ -190,6 +187,5 @@ void torch_TensorOperator_(init)(lua_State *L) luaT_setfuncs(L, torch_TensorOperator_(_), 0); lua_pop(L, 1); } -#endif #endif diff --git a/init.c b/init.c index 0c413f90..3bdac176 100644 --- a/init.c +++ b/init.c @@ -35,28 +35,9 @@ extern void torch_LongTensorOperator_init(lua_State *L); extern void torch_FloatTensorOperator_init(lua_State *L); extern void torch_DoubleTensorOperator_init(lua_State *L); -#if TH_NATIVE_HALF -extern void torch_HalfTensorOperator_init(lua_State *L); -#endif extern void torch_TensorMath_init(lua_State *L); -static int torch_hashalfmath(lua_State *L) { - lua_pushboolean(L, TH_NATIVE_HALF); - return 1; -} - -static void torch_half_init(lua_State *L) -{ - const struct luaL_Reg half_funcs__ [] = { - {"hashalfmath", torch_hashalfmath}, - {NULL, NULL} - }; - luaT_setfuncs(L, half_funcs__, 0); - - lua_pushboolean(L, 1); - lua_setfield(L, -2, "hasHalf"); -} LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L); @@ -69,7 +50,6 @@ int luaopen_libtorch(lua_State *L) torch_utils_init(L); torch_File_init(L); - torch_half_init(L); torch_ByteStorage_init(L); torch_CharStorage_init(L); @@ -97,10 +77,6 @@ int luaopen_libtorch(lua_State *L) torch_FloatTensorOperator_init(L); torch_DoubleTensorOperator_init(L); -#if TH_NATIVE_HALF - torch_HalfTensorOperator_init(L); -#endif - torch_Timer_init(L); torch_DiskFile_init(L); torch_PipeFile_init(L); diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt index 64d7c98c..0b5f1667 100644 --- a/lib/TH/CMakeLists.txt +++ b/lib/TH/CMakeLists.txt @@ -339,6 +339,7 @@ INSTALL(FILES THFilePrivate.h ${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h THGenerateAllTypes.h + THGenerateHalfType.h THGenerateFloatTypes.h THGenerateIntTypes.h THLapack.h diff --git a/lib/TH/THDiskFile.c b/lib/TH/THDiskFile.c index b23a7974..01b19513 100644 --- a/lib/TH/THDiskFile.c +++ b/lib/TH/THDiskFile.c @@ -348,18 +348,14 @@ READ_WRITE_METHODS(int, Int, int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++, int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++) -/*READ_WRITE_METHODS(long, Long, - int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++, - int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++)*/ - READ_WRITE_METHODS(float, Float, int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++, int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++) -#if TH_GENERIC_USE_HALF + READ_WRITE_METHODS(THHalf, Half, float buf; int ret = fscanf(dfself->handle, "%g", &buf); if(ret <= 0) break; else { data[i]= TH_float2half(buf); nread++; }, int ret = fprintf(dfself->handle, "%.9g", TH_half2float(data[i])); if(ret <= 0) break; else nwrite++) -#endif + READ_WRITE_METHODS(double, Double, int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++, int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++) @@ -622,9 +618,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_readLong, THDiskFile_readFloat, THDiskFile_readDouble, -#if TH_GENERIC_USE_HALF THDiskFile_readHalf, -#endif THDiskFile_readString, THDiskFile_writeByte, @@ -634,9 +628,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_writeLong, THDiskFile_writeFloat, THDiskFile_writeDouble, -#if TH_GENERIC_USE_HALF THDiskFile_writeHalf, -#endif THDiskFile_writeString, THDiskFile_synchronize, @@ -740,9 +732,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_readLong, THDiskFile_readFloat, THDiskFile_readDouble, -#if TH_GENERIC_USE_HALF THDiskFile_readHalf, -#endif THDiskFile_readString, THDiskFile_writeByte, @@ -752,9 +742,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) THDiskFile_writeLong, THDiskFile_writeFloat, THDiskFile_writeDouble, -#if TH_GENERIC_USE_HALF THDiskFile_writeHalf, -#endif THDiskFile_writeString, THDiskFile_synchronize, diff --git a/lib/TH/THFile.c b/lib/TH/THFile.c index 58819191..3717b7b5 100644 --- a/lib/TH/THFile.c +++ b/lib/TH/THFile.c @@ -19,9 +19,7 @@ IMPLEMENT_THFILE_RW(Int, int) IMPLEMENT_THFILE_RW(Long, long) IMPLEMENT_THFILE_RW(Float, float) IMPLEMENT_THFILE_RW(Double, double) -#if TH_GENERIC_USE_HALF IMPLEMENT_THFILE_RW(Half, THHalf) -#endif size_t THFile_readStringRaw(THFile *self, const char *format, char **str_) { @@ -136,9 +134,7 @@ IMPLEMENT_THFILE_SCALAR(Int, int) IMPLEMENT_THFILE_SCALAR(Long, long) IMPLEMENT_THFILE_SCALAR(Float, float) IMPLEMENT_THFILE_SCALAR(Double, double) -#if TH_GENERIC_USE_HALF IMPLEMENT_THFILE_SCALAR(Half, THHalf) -#endif #define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \ size_t THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \ @@ -158,6 +154,4 @@ IMPLEMENT_THFILE_STORAGE(Int, int) IMPLEMENT_THFILE_STORAGE(Long, long) IMPLEMENT_THFILE_STORAGE(Float, float) IMPLEMENT_THFILE_STORAGE(Double, double) -#if TH_GENERIC_USE_HALF IMPLEMENT_THFILE_STORAGE(Half, THHalf) -#endif diff --git a/lib/TH/THFile.h b/lib/TH/THFile.h index 2173270e..e097bdf3 100644 --- a/lib/TH/THFile.h +++ b/lib/TH/THFile.h @@ -74,14 +74,12 @@ TH_API size_t THFile_writeFloatRaw(THFile *self, float *data, size_t n); TH_API size_t THFile_writeDoubleRaw(THFile *self, double *data, size_t n); TH_API size_t THFile_writeStringRaw(THFile *self, const char *str, size_t size); -#if TH_GENERIC_USE_HALF TH_API THHalf THFile_readHalfScalar(THFile *self); TH_API void THFile_writeHalfScalar(THFile *self, THHalf scalar); TH_API size_t THFile_readHalf(THFile *self, THHalfStorage *storage); TH_API size_t THFile_writeHalf(THFile *self, THHalfStorage *storage); TH_API size_t THFile_readHalfRaw(THFile *self, THHalf* data, size_t size); TH_API size_t THFile_writeHalfRaw(THFile *self, THHalf* data, size_t size); -#endif TH_API void THFile_synchronize(THFile *self); TH_API void THFile_seek(THFile *self, size_t position); diff --git a/lib/TH/THFilePrivate.h b/lib/TH/THFilePrivate.h index 0de91635..55169c3b 100644 --- a/lib/TH/THFilePrivate.h +++ b/lib/TH/THFilePrivate.h @@ -1,8 +1,6 @@ #include "THGeneral.h" -#if TH_GENERIC_USE_HALF -# include "THHalf.h" -#endif +#include "THHalf.h" struct THFile__ @@ -30,9 +28,7 @@ struct THFileVTable size_t (*readLong)(THFile *self, long *data, size_t n); size_t (*readFloat)(THFile *self, float *data, size_t n); size_t (*readDouble)(THFile *self, double *data, size_t n); -#if TH_GENERIC_USE_HALF size_t (*readHalf)(THFile *self, THHalf *data, size_t n); -#endif size_t (*readString)(THFile *self, const char *format, char **str_); size_t (*writeByte)(THFile *self, unsigned char *data, size_t n); @@ -42,9 +38,7 @@ struct THFileVTable size_t (*writeLong)(THFile *self, long *data, size_t n); size_t (*writeFloat)(THFile *self, float *data, size_t n); size_t (*writeDouble)(THFile *self, double *data, size_t n); -#if TH_GENERIC_USE_HALF size_t (*writeHalf)(THFile *self, THHalf *data, size_t n); -#endif size_t (*writeString)(THFile *self, const char *str, size_t size); void (*synchronize)(THFile *self); diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in index 6873db66..bc7e4482 100644 --- a/lib/TH/THGeneral.h.in +++ b/lib/TH/THGeneral.h.in @@ -46,14 +46,6 @@ #define TH_INDEX_BASE 1 #endif -#ifndef TH_GENERIC_USE_HALF -# define TH_GENERIC_USE_HALF 0 -#endif - -#ifndef TH_NATIVE_HALF -# define TH_NATIVE_HALF 0 -#endif - typedef void (*THErrorHandlerFunction)(const char *msg, void *data); typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data); diff --git a/lib/TH/THGenerateAllTypes.h b/lib/TH/THGenerateAllTypes.h index 4342cd48..4a770812 100644 --- a/lib/TH/THGenerateAllTypes.h +++ b/lib/TH/THGenerateAllTypes.h @@ -8,7 +8,6 @@ #define THInf UCHAR_MAX #define TH_REAL_IS_BYTE #line 1 TH_GENERIC_FILE -/*#line 1 "THByteStorage.h"*/ #include TH_GENERIC_FILE #undef real #undef accreal @@ -94,24 +93,4 @@ #undef THInf #undef TH_REAL_IS_DOUBLE -#if TH_GENERIC_USE_HALF -#include "THHalf.h" -#define real THHalf -#define accreal float -#define Real Half -#define THInf TH_HALF_MAX -#define TH_REAL_IS_HALF -#if !TH_NATIVE_HALF -# define TH_GENERIC_NO_MATH 1 -#endif -#line 1 TH_GENERIC_FILE -#include TH_GENERIC_FILE -#undef real -#undef accreal -#undef Real -#undef THInf -#undef TH_REAL_IS_HALF -#undef TH_GENERIC_NO_MATH -#endif - #undef TH_GENERIC_FILE diff --git a/lib/TH/THGenerateHalfType.h b/lib/TH/THGenerateHalfType.h new file mode 100644 index 00000000..9acc5345 --- /dev/null +++ b/lib/TH/THGenerateHalfType.h @@ -0,0 +1,19 @@ +#ifndef TH_GENERIC_FILE +#error "You must define TH_GENERIC_FILE before including THGenerateHalfType.h" +#endif + +#include "THHalf.h" +#define real THHalf +#define accreal float +#define Real Half +#define THInf TH_HALF_MAX +#define TH_REAL_IS_HALF +#line 1 TH_GENERIC_FILE +#include TH_GENERIC_FILE +#undef real +#undef accreal +#undef Real +#undef THInf +#undef TH_REAL_IS_HALF + +#undef TH_GENERIC_FILE diff --git a/lib/TH/THHalf.h b/lib/TH/THHalf.h index 930477d0..0549d214 100644 --- a/lib/TH/THHalf.h +++ b/lib/TH/THHalf.h @@ -6,18 +6,18 @@ /* Neither built-in nor included from Cutorch, use our definition lifted from CUDA */ #if defined(__GNUC__) -#define __align__(n) __attribute__((aligned(n))) +#define __thalign__(n) __attribute__((aligned(n))) #elif defined(_WIN32) -#define __align__(n) __declspec(align(n)) +#define __thalign__(n) __declspec(align(n)) #else -#define __align__(n) +#define __thalign__(n) #endif -typedef struct __align__(2){ +typedef struct __thalign__(2){ unsigned short x; } __THHalf; -typedef struct __align__(4) { +typedef struct __thalign__(4) { unsigned int x; } __THHalf2; @@ -36,4 +36,5 @@ TH_API float TH_half2float(THHalf a); #define TH_HALF_MAX TH_HALF_BITS_TO_LITERAL(0x7BFF) +#undef __thalign__ #endif diff --git a/lib/TH/THMemoryFile.c b/lib/TH/THMemoryFile.c index 22e28801..ecce6e1b 100644 --- a/lib/TH/THMemoryFile.c +++ b/lib/TH/THMemoryFile.c @@ -332,23 +332,18 @@ READ_WRITE_METHODS(int, Int, nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%d", data[i]), 1) -/*READ_WRITE_METHODS(long, Long, - int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%ld%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, - nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%ld", data[i]), - 1)*/ - READ_WRITE_METHODS(float, Float, int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.9g", data[i]), 1) -#if TH_GENERIC_USE_HALF + READ_WRITE_METHODS(THHalf, Half, int nByteRead_; float buf; \ int ret = sscanf(mfself->storage->data+mfself->position, "%g%n", &buf, &nByteRead_); \ data[i] = TH_float2half(buf); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.9g", TH_half2float(data[i])), 1) -#endif + READ_WRITE_METHODS(double, Double, int nByteRead_; int ret = sscanf(mfself->storage->data+mfself->position, "%lg%n", &data[i], &nByteRead_); nByteRead = nByteRead_; if(ret <= 0) break; else nread++, nByteWritten = snprintf(mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.17g", data[i]), @@ -628,9 +623,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) THMemoryFile_readLong, THMemoryFile_readFloat, THMemoryFile_readDouble, -#if TH_GENERIC_USE_HALF THMemoryFile_readHalf, -#endif THMemoryFile_readString, THMemoryFile_writeByte, @@ -640,9 +633,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) THMemoryFile_writeLong, THMemoryFile_writeFloat, THMemoryFile_writeDouble, -#if TH_GENERIC_USE_HALF THMemoryFile_writeHalf, -#endif THMemoryFile_writeString, THMemoryFile_synchronize, diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c index d18488ef..bb63a435 100644 --- a/lib/TH/THStorage.c +++ b/lib/TH/THStorage.c @@ -4,5 +4,11 @@ #include "generic/THStorage.c" #include "THGenerateAllTypes.h" +#include "generic/THStorage.c" +#include "THGenerateHalfType.h" + #include "generic/THStorageCopy.c" #include "THGenerateAllTypes.h" + +#include "generic/THStorageCopy.c" +#include "THGenerateHalfType.h" diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h index 36ed507b..9565e10a 100644 --- a/lib/TH/THStorage.h +++ b/lib/TH/THStorage.h @@ -14,7 +14,13 @@ #include "generic/THStorage.h" #include "THGenerateAllTypes.h" +#include "generic/THStorage.h" +#include "THGenerateHalfType.h" + #include "generic/THStorageCopy.h" #include "THGenerateAllTypes.h" +#include "generic/THStorageCopy.h" +#include "THGenerateHalfType.h" + #endif diff --git a/lib/TH/THTensor.c b/lib/TH/THTensor.c index 2878fc99..37071df8 100644 --- a/lib/TH/THTensor.c +++ b/lib/TH/THTensor.c @@ -11,9 +11,15 @@ #include "generic/THTensor.c" #include "THGenerateAllTypes.h" +#include "generic/THTensor.c" +#include "THGenerateHalfType.h" + #include "generic/THTensorCopy.c" #include "THGenerateAllTypes.h" +#include "generic/THTensorCopy.c" +#include "THGenerateHalfType.h" + #include "generic/THTensorRandom.c" #include "THGenerateAllTypes.h" diff --git a/lib/TH/THTensor.h b/lib/TH/THTensor.h index 6eddf9c7..a155efde 100644 --- a/lib/TH/THTensor.h +++ b/lib/TH/THTensor.h @@ -16,9 +16,15 @@ typedef struct { #include "generic/THTensor.h" #include "THGenerateAllTypes.h" +#include "generic/THTensor.h" +#include "THGenerateHalfType.h" + #include "generic/THTensorCopy.h" #include "THGenerateAllTypes.h" +#include "generic/THTensorCopy.h" +#include "THGenerateHalfType.h" + #include "THTensorMacros.h" /* random numbers */ diff --git a/lib/TH/generic/THBlas.c b/lib/TH/generic/THBlas.c index 8b3a4032..371df4d4 100644 --- a/lib/TH/generic/THBlas.c +++ b/lib/TH/generic/THBlas.c @@ -2,7 +2,6 @@ #define TH_GENERIC_FILE "generic/THBlas.c" #else -# ifndef TH_GENERIC_NO_MATH #ifdef BLAS_F2C # define ffloat double @@ -404,5 +403,5 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, } } } -# endif /* TH_GENERIC_NO_MATH */ + #endif diff --git a/lib/TH/generic/THBlas.h b/lib/TH/generic/THBlas.h index a49d79cd..9e14f5a8 100644 --- a/lib/TH/generic/THBlas.h +++ b/lib/TH/generic/THBlas.h @@ -1,7 +1,7 @@ #ifndef TH_GENERIC_FILE #define TH_GENERIC_FILE "generic/THBlas.h" #else -# ifndef TH_GENERIC_NO_MATH + /* Level 1 */ TH_API void THBlas_(swap)(long n, real *x, long incx, real *y, long incy); TH_API void THBlas_(scal)(long n, real a, real *x, long incx); @@ -15,5 +15,5 @@ TH_API void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y /* Level 3 */ TH_API void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, real *a, long lda, real *b, long ldb, real beta, real *c, long ldc); -# endif + #endif diff --git a/lib/TH/generic/THStorageCopy.c b/lib/TH/generic/THStorageCopy.c index bc2d7e70..ce4b57ea 100644 --- a/lib/TH/generic/THStorageCopy.c +++ b/lib/TH/generic/THStorageCopy.c @@ -58,9 +58,7 @@ IMPLEMENT_THStorage_COPY(Int) IMPLEMENT_THStorage_COPY(Long) IMPLEMENT_THStorage_COPY(Float) IMPLEMENT_THStorage_COPY(Double) -#if TH_GENERIC_USE_HALF IMPLEMENT_THStorage_COPY_FROM_HALF(Half) -#endif #else /* only allow pass-through for Half */ IMPLEMENT_THStorage_COPY_TO_FROM_HALF(Half) diff --git a/lib/TH/generic/THStorageCopy.h b/lib/TH/generic/THStorageCopy.h index bb2f406b..ce8a2a69 100644 --- a/lib/TH/generic/THStorageCopy.h +++ b/lib/TH/generic/THStorageCopy.h @@ -13,8 +13,6 @@ TH_API void THStorage_(copyInt)(THStorage *storage, struct THIntStorage *src); TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src); TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src); TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src); -#if TH_GENERIC_USE_HALF TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src); -#endif #endif diff --git a/lib/TH/generic/THTensorConv.c b/lib/TH/generic/THTensorConv.c index aa864ada..1e219915 100644 --- a/lib/TH/generic/THTensorConv.c +++ b/lib/TH/generic/THTensorConv.c @@ -2,9 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorConv.c" #else -/* Tensor math may be disabled for certain types, e.g. 'TH_half' */ -#ifndef TH_GENERIC_NO_MATH - /* 2D Input, 2D kernel : convolve given image with the given kernel. */ @@ -1958,4 +1955,3 @@ void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT THTensor_(free)(kernel); } #endif -#endif diff --git a/lib/TH/generic/THTensorConv.h b/lib/TH/generic/THTensorConv.h index 2248d458..79866f39 100644 --- a/lib/TH/generic/THTensorConv.h +++ b/lib/TH/generic/THTensorConv.h @@ -2,8 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorConv.h" #else -#ifndef TH_GENERIC_NO_MATH - TH_API void THTensor_(validXCorr2Dptr)(real *r_, real alpha, real *t_, long ir, long ic, @@ -67,7 +65,7 @@ TH_API void THTensor_(fullConv3Dptr)(real *r_, long st, long sr, long sc); TH_API void THTensor_(validXCorr3DRevptr)(real *r_, - real alpha, + real alpha, real *t_, long it, long ir, long ic, real *k_, long kt, long kr, long kc, long st, long sr, long sc); @@ -79,4 +77,3 @@ TH_API void THTensor_(conv3Dmul)(THTensor *r_, real beta, real alpha, THTensor * TH_API void THTensor_(conv3Dcmul)(THTensor *r_, real beta, real alpha, THTensor *t_, THTensor *k_, long sdepth, long srow, long scol, const char *vf, const char *xc); #endif -#endif diff --git a/lib/TH/generic/THTensorCopy.c b/lib/TH/generic/THTensorCopy.c index 2ce5e1fa..d7228f68 100644 --- a/lib/TH/generic/THTensorCopy.c +++ b/lib/TH/generic/THTensorCopy.c @@ -33,9 +33,7 @@ IMPLEMENT_THTensor_COPY(Int, int) IMPLEMENT_THTensor_COPY(Long, long) IMPLEMENT_THTensor_COPY(Float, float) IMPLEMENT_THTensor_COPY(Double, double) -#if TH_GENERIC_USE_HALF IMPLEMENT_THTensor_COPY_FROM_HALF(Half, THHalf) -#endif #else /* only allow pass-through for Half */ IMPLEMENT_THTensor_COPY(Half, THHalf) diff --git a/lib/TH/generic/THTensorCopy.h b/lib/TH/generic/THTensorCopy.h index 8a0abcff..b9e5bfc9 100644 --- a/lib/TH/generic/THTensorCopy.h +++ b/lib/TH/generic/THTensorCopy.h @@ -12,7 +12,6 @@ TH_API void THTensor_(copyInt)(THTensor *tensor, struct THIntTensor *src); TH_API void THTensor_(copyLong)(THTensor *tensor, struct THLongTensor *src); TH_API void THTensor_(copyFloat)(THTensor *tensor, struct THFloatTensor *src); TH_API void THTensor_(copyDouble)(THTensor *tensor, struct THDoubleTensor *src); -#if TH_GENERIC_USE_HALF TH_API void THTensor_(copyHalf)(THTensor *tensor, struct THHalfTensor *src); -#endif + #endif diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 592413cc..d8251b16 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -2,7 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorMath.c" #else -#ifndef TH_GENERIC_NO_MATH #define TH_OMP_OVERHEAD_THRESHOLD 100000 void THTensor_(fill)(THTensor *r_, real value) @@ -2546,9 +2545,9 @@ void THTensor_(histc)(THTensor *hist, THTensor *tensor, long nbins, real minvalu } void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, long nbins, real minvalue, real maxvalue) -{ +{ THArgCheck(THTensor_(nDimension)(tensor) < 3, 2, "invalid dimension %d, the input must be a 2d tensor", THTensor_(nDimension)(tensor)); - + int dimension = 1; THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(tensor), 2, "invalid dimension %d", dimension + TH_INDEX_BASE); @@ -2575,7 +2574,7 @@ void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, long nbins, real minval TH_TENSOR_DIM_APPLY2(real, tensor, real, hist, dimension, long i; for(i = 0; i < tensor_size; i++) - { + { if(tensor_data[i*tensor_stride] >= minval && tensor_data[i*tensor_stride] <= maxval) { const int bin = (int)((tensor_data[i*tensor_stride]-minval) / (maxval-minval) * nbins); hist_data[THMin(bin, nbins-1)] += 1; @@ -2586,5 +2585,4 @@ void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, long nbins, real minval #endif /* floating point only part */ #undef IS_NONZERO -#endif /* TH_GENERIC_NO_MATH */ #endif diff --git a/lib/TH/generic/THTensorMath.h b/lib/TH/generic/THTensorMath.h index 476c2b88..c656dfdb 100644 --- a/lib/TH/generic/THTensorMath.h +++ b/lib/TH/generic/THTensorMath.h @@ -2,8 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorMath.h" #else -#ifndef TH_GENERIC_NO_MATH - TH_API void THTensor_(fill)(THTensor *r_, real value); TH_API void THTensor_(zero)(THTensor *r_); @@ -183,5 +181,4 @@ TH_API int THTensor_(logicalany)(THTensor *self); #endif /* TH_REAL_IS_BYTE */ -#endif /* TH_GENERIC_NO_MATH */ #endif diff --git a/lib/TH/generic/THTensorRandom.c b/lib/TH/generic/THTensorRandom.c index 18b0471c..514d3dd2 100644 --- a/lib/TH/generic/THTensorRandom.c +++ b/lib/TH/generic/THTensorRandom.c @@ -2,9 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorRandom.c" #else -/* Tensor math may be disabled for certain types, e.g. 'half' */ -#ifndef TH_GENERIC_NO_MATH - void THTensor_(random)(THTensor *self, THGenerator *_generator) { #if defined(TH_REAL_IS_BYTE) @@ -250,6 +247,4 @@ void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self) } #endif -#endif /* TH_GENERIC_NO_MATH */ - #endif diff --git a/lib/TH/generic/THTensorRandom.h b/lib/TH/generic/THTensorRandom.h index bb5ed915..d2051424 100644 --- a/lib/TH/generic/THTensorRandom.h +++ b/lib/TH/generic/THTensorRandom.h @@ -2,8 +2,6 @@ #define TH_GENERIC_FILE "generic/THTensorRandom.h" #else -#ifndef TH_GENERIC_NO_MATH - TH_API void THTensor_(random)(THTensor *self, THGenerator *_generator); TH_API void THTensor_(geometric)(THTensor *self, THGenerator *_generator, double p); TH_API void THTensor_(bernoulli)(THTensor *self, THGenerator *_generator, double p); @@ -25,4 +23,3 @@ TH_API void THTensor_(setRNGState)(THGenerator *_generator, THTensor *self); #endif #endif -#endif diff --git a/lib/TH/generic/THVector.h b/lib/TH/generic/THVector.h index eaeb008a..67fdcfa8 100644 --- a/lib/TH/generic/THVector.h +++ b/lib/TH/generic/THVector.h @@ -2,13 +2,11 @@ #define TH_GENERIC_FILE "generic/THVector.h" #else -#ifndef TH_GENERIC_NO_MATH TH_API void THVector_(fill)(real *x, const real c, const ptrdiff_t n); TH_API void THVector_(add)(real *y, const real *x, const real c, const ptrdiff_t n); TH_API void THVector_(diff)(real *z, const real *x, const real *y, const ptrdiff_t n); TH_API void THVector_(scale)(real *y, const real c, const ptrdiff_t n); TH_API void THVector_(mul)(real *y, const real *x, const ptrdiff_t n); -#endif /* Initialize the dispatch pointers */ TH_API void THVector_(vectorDispatchInit)(void); diff --git a/lib/TH/generic/THVectorDefault.c b/lib/TH/generic/THVectorDefault.c index 7554d45e..aabc16c5 100644 --- a/lib/TH/generic/THVectorDefault.c +++ b/lib/TH/generic/THVectorDefault.c @@ -2,8 +2,6 @@ #define TH_GENERIC_FILE "generic/THVectorDefault.c" #else -#ifndef TH_GENERIC_NO_MATH - void THVector_(fill_DEFAULT)(real *x, const real c, const ptrdiff_t n) { ptrdiff_t i = 0; @@ -82,5 +80,5 @@ void THVector_(mul_DEFAULT)(real *y, const real *x, const ptrdiff_t n) for(; i < n; i++) y[i] *= x[i]; } -# endif + #endif diff --git a/lib/TH/generic/THVectorDispatch.c b/lib/TH/generic/THVectorDispatch.c index a93587d1..2436a125 100644 --- a/lib/TH/generic/THVectorDispatch.c +++ b/lib/TH/generic/THVectorDispatch.c @@ -2,8 +2,6 @@ #define TH_GENERIC_FILE "generic/THVectorDispatch.c" #else -#ifndef TH_GENERIC_NO_MATH - /* For now there are only SIMD implementations for FLOAT and DOUBLE. * Hopefully in the future this can be made totally generic (e.g, there are SIMD implementations * for a lot of functions */ @@ -167,5 +165,5 @@ void THVector_(vectorDispatchInit)(void) INIT_DISPATCH_PTR(scale); INIT_DISPATCH_PTR(mul); } -#endif /* TH_GENERIC_NO_MATH */ + #endif