Skip to content

Commit

Permalink
Removing TH_GENERIC_USE_HALF, TH_NATIVE_HALF, TH_GENERIC_NO_MATH (rep…
Browse files Browse the repository at this point in the history
…laced where appropriate with TH_REAL_IS_HALF), removed half from THGenerateAllTypes, added an explicit THGenerateHalfType.h
  • Loading branch information
soumith committed Jan 1, 2017
1 parent 247f200 commit 1e86025
Show file tree
Hide file tree
Showing 39 changed files with 79 additions and 167 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ IF(MSVC)
ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1)
ENDIF(MSVC)

ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1)

# OpenMP support?
SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?")
IF (APPLE AND CMAKE_COMPILER_IS_GNUCC)
Expand Down
2 changes: 1 addition & 1 deletion FFI.lua
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ typedef struct THRealTensor
end)

-- faster apply (contiguous case)
if Tensor_type ~= 'torch.HalfTensor' or torch.hashalfmath() then
if Tensor_type ~= 'torch.HalfTensor' then
local apply = Tensor.apply
rawset(Tensor,
"apply",
Expand Down
3 changes: 3 additions & 0 deletions Storage.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@

#include "generic/Storage.c"
#include "THGenerateAllTypes.h"

#include "generic/Storage.c"
#include "THGenerateHalfType.h"
3 changes: 3 additions & 0 deletions Tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@

#include "generic/Tensor.c"
#include "THGenerateAllTypes.h"

#include "generic/Tensor.c"
#include "THGenerateHalfType.h"
2 changes: 1 addition & 1 deletion Tensor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ torch.permute = Tensor.permute
for _,type in ipairs(types) do
local metatable = torch.getmetatable('torch.' .. type .. 'Tensor')
for funcname, func in pairs(Tensor) do
if funcname ~= 'totable' or type ~='Half' or torch.hashalfmath() then
if funcname ~= 'totable' or type ~='Half' then
rawset(metatable, funcname, func)
else
local function Tensor__totable(self)
Expand Down
5 changes: 1 addition & 4 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
{name="long",default=100},
{name="double",default=0},
{name="double",default=0}})

wrap("bhistc",
cname("bhistc"),
{{name=Tensor, default=true, returned=true},
Expand Down Expand Up @@ -1446,9 +1446,6 @@ void torch_TensorMath_init(lua_State *L)
torch_IntTensorMath_init(L);
torch_LongTensorMath_init(L);
torch_FloatTensorMath_init(L);
#if TH_NATIVE_HALF
torch_HalfTensorMath_init(L);
#endif
torch_DoubleTensorMath_init(L);
luaT_setfuncs(L, torch_TensorMath__, 0);
}
Expand Down
2 changes: 0 additions & 2 deletions generic/Storage.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,8 @@ static int torch_Storage_(copy)(lua_State *L)
THStorage_(copyFloat)(storage, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
THStorage_(copyDouble)(storage, src);
#if TH_GENERIC_USE_HALF
else if( (src = luaT_toudata(L, 2, "torch.HalfStorage")) )
THStorage_(copyHalf)(storage, src);
#endif
else
luaL_typerror(L, 2, "torch.*Storage");
lua_settop(L, 1);
Expand Down
16 changes: 5 additions & 11 deletions generic/Tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ static int torch_Tensor_(select)(lua_State *L)
return 1;
}

#ifndef TH_GENERIC_NO_MATH
#ifndef TH_REAL_IS_HALF
static int torch_Tensor_(indexSelect)(lua_State *L)
{
int narg = lua_gettop(L);
Expand Down Expand Up @@ -677,10 +677,8 @@ static int torch_Tensor_(copy)(lua_State *L)
THTensor_(copyFloat)(tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
THTensor_(copyDouble)(tensor, src);
#if TH_GENERIC_USE_HALF
else if( (src = luaT_toudata(L, 2, "torch.HalfTensor")) )
THTensor_(copyHalf)(tensor, src);
#endif
else
luaL_typerror(L, 2, "torch.*Tensor");
lua_settop(L, 1);
Expand Down Expand Up @@ -755,13 +753,11 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyDouble)(tensor, src);
THTensor_(free)(tensor);
#if TH_GENERIC_USE_HALF
} else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyHalf)(tensor, src);
THTensor_(free)(tensor);
#endif
} else {
luaL_typerror(L, 3, "torch.*Tensor");
}
Expand Down Expand Up @@ -867,10 +863,8 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THTensor_(copyFloat)(tensor, src);
} else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) {
THTensor_(copyDouble)(tensor, src);
#if TH_GENERIC_USE_HALF
} else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) {
THTensor_(copyHalf)(tensor, src);
#endif
} else {
luaL_typerror(L, 3, "torch.*Tensor");
}
Expand Down Expand Up @@ -1165,7 +1159,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index
THArgCheck(0, index, "expecting number");
}

#ifndef TH_GENERIC_NO_MATH
#ifndef TH_REAL_IS_HALF
static int torch_Tensor_(apply)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
Expand Down Expand Up @@ -1322,7 +1316,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
{"narrow", torch_Tensor_(narrow)},
{"sub", torch_Tensor_(sub)},
{"select", torch_Tensor_(select)},
#ifndef TH_GENERIC_NO_MATH
#ifndef TH_REAL_IS_HALF
{"index", torch_Tensor_(indexSelect)},
{"indexCopy", torch_Tensor_(indexCopy)},
{"indexAdd", torch_Tensor_(indexAdd)},
Expand All @@ -1340,7 +1334,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
{"isSize", torch_Tensor_(isSize)},
{"nElement", torch_Tensor_(nElement)},
{"copy", torch_Tensor_(copy)},
#ifndef TH_GENERIC_NO_MATH
#ifndef TH_REAL_IS_HALF
{"apply", torch_Tensor_(apply)},
{"map", torch_Tensor_(map)},
{"map2", torch_Tensor_(map2)},
Expand All @@ -1358,7 +1352,7 @@ void torch_Tensor_(init)(lua_State *L)
torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
luaT_setfuncs(L, torch_Tensor_(_), 0);
lua_pop(L, 1);
#ifndef TH_GENERIC_NO_MATH
#ifndef TH_REAL_IS_HALF
THVector_(vectorDispatchInit)();
#endif

Expand Down
4 changes: 0 additions & 4 deletions generic/TensorOperator.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
#define TH_GENERIC_FILE "generic/TensorOperator.c"
#else

/* Tensor math may be disabled for certain types, e.g. 'half' */
#ifndef TH_GENERIC_NO_MATH

static int torch_TensorOperator_(__add__)(lua_State *L)
{
THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor);
Expand Down Expand Up @@ -190,6 +187,5 @@ void torch_TensorOperator_(init)(lua_State *L)
luaT_setfuncs(L, torch_TensorOperator_(_), 0);
lua_pop(L, 1);
}
#endif

#endif
24 changes: 0 additions & 24 deletions init.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,9 @@ extern void torch_LongTensorOperator_init(lua_State *L);
extern void torch_FloatTensorOperator_init(lua_State *L);
extern void torch_DoubleTensorOperator_init(lua_State *L);

#if TH_NATIVE_HALF
extern void torch_HalfTensorOperator_init(lua_State *L);
#endif

extern void torch_TensorMath_init(lua_State *L);

static int torch_hashalfmath(lua_State *L) {
lua_pushboolean(L, TH_NATIVE_HALF);
return 1;
}

static void torch_half_init(lua_State *L)
{
const struct luaL_Reg half_funcs__ [] = {
{"hashalfmath", torch_hashalfmath},
{NULL, NULL}
};
luaT_setfuncs(L, half_funcs__, 0);

lua_pushboolean(L, 1);
lua_setfield(L, -2, "hasHalf");
}

LUA_EXTERNC DLL_EXPORT int luaopen_libtorch(lua_State *L);

Expand All @@ -69,7 +50,6 @@ int luaopen_libtorch(lua_State *L)

torch_utils_init(L);
torch_File_init(L);
torch_half_init(L);

torch_ByteStorage_init(L);
torch_CharStorage_init(L);
Expand Down Expand Up @@ -97,10 +77,6 @@ int luaopen_libtorch(lua_State *L)
torch_FloatTensorOperator_init(L);
torch_DoubleTensorOperator_init(L);

#if TH_NATIVE_HALF
torch_HalfTensorOperator_init(L);
#endif

torch_Timer_init(L);
torch_DiskFile_init(L);
torch_PipeFile_init(L);
Expand Down
1 change: 1 addition & 0 deletions lib/TH/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ INSTALL(FILES
THFilePrivate.h
${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h
THGenerateAllTypes.h
THGenerateHalfType.h
THGenerateFloatTypes.h
THGenerateIntTypes.h
THLapack.h
Expand Down
16 changes: 2 additions & 14 deletions lib/TH/THDiskFile.c
Original file line number Diff line number Diff line change
Expand Up @@ -348,18 +348,14 @@ READ_WRITE_METHODS(int, Int,
int ret = fscanf(dfself->handle, "%d", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%d", data[i]); if(ret <= 0) break; else nwrite++)

/*READ_WRITE_METHODS(long, Long,
int ret = fscanf(dfself->handle, "%ld", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%ld", data[i]); if(ret <= 0) break; else nwrite++)*/

READ_WRITE_METHODS(float, Float,
int ret = fscanf(dfself->handle, "%g", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%.9g", data[i]); if(ret <= 0) break; else nwrite++)
#if TH_GENERIC_USE_HALF

READ_WRITE_METHODS(THHalf, Half,
float buf; int ret = fscanf(dfself->handle, "%g", &buf); if(ret <= 0) break; else { data[i]= TH_float2half(buf); nread++; },
int ret = fprintf(dfself->handle, "%.9g", TH_half2float(data[i])); if(ret <= 0) break; else nwrite++)
#endif

READ_WRITE_METHODS(double, Double,
int ret = fscanf(dfself->handle, "%lg", &data[i]); if(ret <= 0) break; else nread++,
int ret = fprintf(dfself->handle, "%.17g", data[i]); if(ret <= 0) break; else nwrite++)
Expand Down Expand Up @@ -622,9 +618,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_readLong,
THDiskFile_readFloat,
THDiskFile_readDouble,
#if TH_GENERIC_USE_HALF
THDiskFile_readHalf,
#endif
THDiskFile_readString,

THDiskFile_writeByte,
Expand All @@ -634,9 +628,7 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_writeLong,
THDiskFile_writeFloat,
THDiskFile_writeDouble,
#if TH_GENERIC_USE_HALF
THDiskFile_writeHalf,
#endif
THDiskFile_writeString,

THDiskFile_synchronize,
Expand Down Expand Up @@ -740,9 +732,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_readLong,
THDiskFile_readFloat,
THDiskFile_readDouble,
#if TH_GENERIC_USE_HALF
THDiskFile_readHalf,
#endif
THDiskFile_readString,

THDiskFile_writeByte,
Expand All @@ -752,9 +742,7 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet)
THDiskFile_writeLong,
THDiskFile_writeFloat,
THDiskFile_writeDouble,
#if TH_GENERIC_USE_HALF
THDiskFile_writeHalf,
#endif
THDiskFile_writeString,

THDiskFile_synchronize,
Expand Down
6 changes: 0 additions & 6 deletions lib/TH/THFile.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ IMPLEMENT_THFILE_RW(Int, int)
IMPLEMENT_THFILE_RW(Long, long)
IMPLEMENT_THFILE_RW(Float, float)
IMPLEMENT_THFILE_RW(Double, double)
#if TH_GENERIC_USE_HALF
IMPLEMENT_THFILE_RW(Half, THHalf)
#endif

size_t THFile_readStringRaw(THFile *self, const char *format, char **str_)
{
Expand Down Expand Up @@ -136,9 +134,7 @@ IMPLEMENT_THFILE_SCALAR(Int, int)
IMPLEMENT_THFILE_SCALAR(Long, long)
IMPLEMENT_THFILE_SCALAR(Float, float)
IMPLEMENT_THFILE_SCALAR(Double, double)
#if TH_GENERIC_USE_HALF
IMPLEMENT_THFILE_SCALAR(Half, THHalf)
#endif

#define IMPLEMENT_THFILE_STORAGE(TYPEC, TYPE) \
size_t THFile_read##TYPEC(THFile *self, TH##TYPEC##Storage *storage) \
Expand All @@ -158,6 +154,4 @@ IMPLEMENT_THFILE_STORAGE(Int, int)
IMPLEMENT_THFILE_STORAGE(Long, long)
IMPLEMENT_THFILE_STORAGE(Float, float)
IMPLEMENT_THFILE_STORAGE(Double, double)
#if TH_GENERIC_USE_HALF
IMPLEMENT_THFILE_STORAGE(Half, THHalf)
#endif
2 changes: 0 additions & 2 deletions lib/TH/THFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,12 @@ TH_API size_t THFile_writeFloatRaw(THFile *self, float *data, size_t n);
TH_API size_t THFile_writeDoubleRaw(THFile *self, double *data, size_t n);
TH_API size_t THFile_writeStringRaw(THFile *self, const char *str, size_t size);

#if TH_GENERIC_USE_HALF
TH_API THHalf THFile_readHalfScalar(THFile *self);
TH_API void THFile_writeHalfScalar(THFile *self, THHalf scalar);
TH_API size_t THFile_readHalf(THFile *self, THHalfStorage *storage);
TH_API size_t THFile_writeHalf(THFile *self, THHalfStorage *storage);
TH_API size_t THFile_readHalfRaw(THFile *self, THHalf* data, size_t size);
TH_API size_t THFile_writeHalfRaw(THFile *self, THHalf* data, size_t size);
#endif

TH_API void THFile_synchronize(THFile *self);
TH_API void THFile_seek(THFile *self, size_t position);
Expand Down
8 changes: 1 addition & 7 deletions lib/TH/THFilePrivate.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "THGeneral.h"

#if TH_GENERIC_USE_HALF
# include "THHalf.h"
#endif
#include "THHalf.h"


struct THFile__
Expand Down Expand Up @@ -30,9 +28,7 @@ struct THFileVTable
size_t (*readLong)(THFile *self, long *data, size_t n);
size_t (*readFloat)(THFile *self, float *data, size_t n);
size_t (*readDouble)(THFile *self, double *data, size_t n);
#if TH_GENERIC_USE_HALF
size_t (*readHalf)(THFile *self, THHalf *data, size_t n);
#endif
size_t (*readString)(THFile *self, const char *format, char **str_);

size_t (*writeByte)(THFile *self, unsigned char *data, size_t n);
Expand All @@ -42,9 +38,7 @@ struct THFileVTable
size_t (*writeLong)(THFile *self, long *data, size_t n);
size_t (*writeFloat)(THFile *self, float *data, size_t n);
size_t (*writeDouble)(THFile *self, double *data, size_t n);
#if TH_GENERIC_USE_HALF
size_t (*writeHalf)(THFile *self, THHalf *data, size_t n);
#endif
size_t (*writeString)(THFile *self, const char *str, size_t size);

void (*synchronize)(THFile *self);
Expand Down
8 changes: 0 additions & 8 deletions lib/TH/THGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@
#define TH_INDEX_BASE 1
#endif

#ifndef TH_GENERIC_USE_HALF
# define TH_GENERIC_USE_HALF 0
#endif

#ifndef TH_NATIVE_HALF
# define TH_NATIVE_HALF 0
#endif

typedef void (*THErrorHandlerFunction)(const char *msg, void *data);
typedef void (*THArgErrorHandlerFunction)(int argNumber, const char *msg, void *data);

Expand Down
Loading

0 comments on commit 1e86025

Please sign in to comment.