diff --git a/File.lua b/File.lua index 62249a36..8ef9c71a 100644 --- a/File.lua +++ b/File.lua @@ -376,15 +376,21 @@ function File:readObject() end end --- simple helpers to save/load arbitrary objects/tables +-- simple helpers to save/load arbitrary objects/tables function torch.save(filename, object, mode, referenced) - assert(mode == nil or mode == 'binary' or mode == 'ascii', '"binary" or "ascii" (or nil) expected for mode') + assert(mode == nil or mode == 'binary' or mode == 'b32' or mode == 'b64' or mode == 'ascii', '"binary" or "ascii" (or nil) expected for mode') assert(referenced == nil or referenced == true or referenced == false, 'true or false (or nil) expected for referenced') + local longSize + if mode == 'b32' or mode == 'b64' then + longSize = tonumber(mode:match('%d+')) / 8 + mode = 'binary' + end mode = mode or 'binary' referenced = referenced == nil and true or referenced local file = torch.DiskFile(filename, 'w') file[mode](file) file:referenced(referenced) + if longSize then file:longSize(longSize) end file:writeObject(object) file:close() end diff --git a/lib/TH/THGeneral.h.in b/lib/TH/THGeneral.h.in index de11f1b1..b86bad2e 100644 --- a/lib/TH/THGeneral.h.in +++ b/lib/TH/THGeneral.h.in @@ -14,6 +14,7 @@ #cmakedefine USE_BLAS #cmakedefine USE_LAPACK #cmakedefine BLAS_F2C +#cmakedefine MKL_ILP64 #ifdef __cplusplus # define TH_EXTERNC extern "C" diff --git a/lib/TH/THLapack.h b/lib/TH/THLapack.h index 614d15f9..7f715542 100644 --- a/lib/TH/THLapack.h +++ b/lib/TH/THLapack.h @@ -21,6 +21,17 @@ if (info < 0) { \ THError(fmt, func, info, ##__VA_ARGS__); \ } +#ifdef MKL_ILP64 +// set 64 bit MKL integer type +#if (!defined(__INTEL_COMPILER)) & defined(_MSC_VER) +#define LAPACK_INT __int64 +#else +#define LAPACK_INT long long int +#endif +#else +#define LAPACK_INT int +#endif + #include "generic/THLapack.h" #include "THGenerateAllTypes.h" diff --git a/lib/TH/cmake/FindBLAS.cmake b/lib/TH/cmake/FindBLAS.cmake index 2188fc72..8384bb70 100644 --- 
a/lib/TH/cmake/FindBLAS.cmake +++ b/lib/TH/cmake/FindBLAS.cmake @@ -242,32 +242,70 @@ endif() # Determine if blas was compiled with the f2c conventions IF (BLAS_LIBRARIES) SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) - CHECK_C_SOURCE_RUNS(" + IF (MKL_ILP64) + SET(CMAKE_REQUIRED_DEFINITIONS -DMKL_ILP64) + MESSAGE(STATUS "Checking F2C with MKL ILP64 ${CMAKE_REQUIRED_DEFINITIONS}") + ENDIF(MKL_ILP64) + + set(f2c_code_d " #include <stdlib.h> #include <stdio.h> float x[4] = { 1, 2, 3, 4 }; float y[4] = { .1, .01, .001, .0001 }; -int four = 4; -int one = 1; +#ifdef MKL_ILP64 + #if (!defined(__INTEL_COMPILER)) & defined(_MSC_VER) + #define BLAS_INT __int64 + #else + #define BLAS_INT long long + #endif +#else + #define BLAS_INT int +#endif +BLAS_INT four = 4; +BLAS_INT one = 1; extern double sdot_(); int main() { - int i; double r = sdot_(&four, x, &one, y, &one); exit((float)r != (float).1234); -}" BLAS_F2C_DOUBLE_WORKS ) - CHECK_C_SOURCE_RUNS(" +}" ) + + CHECK_C_SOURCE_COMPILES("${f2c_code_d}" BLAS_F2C_DOUBLE_COMPILES ) + IF (NOT BLAS_F2C_DOUBLE_COMPILES) + MESSAGE(STATUS "Warning F2C double check did not compile!!") + MESSAGE(STATUS "${f2c_code_d}") + ENDIF(NOT BLAS_F2C_DOUBLE_COMPILES) + + CHECK_C_SOURCE_RUNS("${f2c_code_d}" BLAS_F2C_DOUBLE_WORKS ) + + set(f2c_code_f " #include <stdlib.h> #include <stdio.h> float x[4] = { 1, 2, 3, 4 }; float y[4] = { .1, .01, .001, .0001 }; -int four = 4; -int one = 1; +#ifdef MKL_ILP64 + #if (!defined(__INTEL_COMPILER)) & defined(_MSC_VER) + #define BLAS_INT __int64 + #else + #define BLAS_INT long long + #endif +#else + #define BLAS_INT int +#endif +BLAS_INT four = 4; +BLAS_INT one = 1; extern float sdot_(); int main() { - int i; double r = sdot_(&four, x, &one, y, &one); exit((float)r != (float).1234); -}" BLAS_F2C_FLOAT_WORKS ) +}" ) + + CHECK_C_SOURCE_COMPILES("${f2c_code_f}" BLAS_F2C_FLOAT_COMPILES ) + IF (NOT BLAS_F2C_FLOAT_COMPILES) + MESSAGE(STATUS "Warning F2C float check did not compile!!") + ENDIF(NOT BLAS_F2C_FLOAT_COMPILES) + + 
CHECK_C_SOURCE_RUNS("${f2c_code_f}" BLAS_F2C_FLOAT_WORKS ) + IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) MESSAGE(STATUS "This BLAS uses the F2C return conventions") SET(BLAS_F2C TRUE) diff --git a/lib/TH/cmake/FindMKL.cmake b/lib/TH/cmake/FindMKL.cmake index 7c9325a7..d949dee7 100644 --- a/lib/TH/cmake/FindMKL.cmake +++ b/lib/TH/cmake/FindMKL.cmake @@ -29,19 +29,27 @@ INCLUDE(CheckTypeSize) INCLUDE(CheckFunctionExists) # Intel Compiler Suite -SET(INTEL_COMPILER_DIR CACHE STRING +SET(INTEL_COMPILER_DIR $ENV{INTEL_COMPILER_DIR} CACHE STRING "Root directory of the Intel Compiler Suite (contains ipp, mkl, etc.)") -SET(INTEL_MKL_DIR CACHE STRING +SET(INTEL_MKL_DIR $ENV{INTEL_MKL_DIR} CACHE STRING "Root directory of the Intel MKL (standalone)") +SET(MKL_ILP64 $ENV{MKL_ILP64} CACHE STRING + "Link with 64bit-interger version of MKL (_ilp64 instead of _lp64)") SET(INTEL_MKL_SEQUENTIAL OFF CACHE BOOL "Force using the sequential (non threaded) libraries") +MESSAGE(STATUS "INTEL_MKL_DIR: ${INTEL_MKL_DIR}") + # Checks CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP) IF ("${SIZE_OF_VOIDP}" EQUAL 8) - SET(mklvers "em64t") + SET(mklvers "intel64") SET(iccvers "intel64") - SET(mkl64s "_lp64") + IF (MKL_ILP64) + SET(mkl64s "_ilp64") + ELSE(MKL_ILP64) + SET(mkl64s "_lp64") + ENDIF(MKL_ILP64) ELSE ("${SIZE_OF_VOIDP}" EQUAL 8) SET(mklvers "32") SET(iccvers "ia32") @@ -80,15 +88,26 @@ ENDIF (INTEL_COMPILER_DIR) IF (INTEL_MKL_DIR) # TODO: diagnostic if dir does not exist SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} - "${INTEL_MKL_DIR}/include") + "${INTEL_MKL_DIR}/include/") SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} - "${INTEL_MKL_DIR}/lib/${mklvers}") + "${INTEL_MKL_DIR}/lib/${mklvers}/") IF (MSVC) SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} - "${INTEL_MKL_DIR}/lib/${iccvers}") + "${INTEL_MKL_DIR}/lib/${iccvers}/") ENDIF (MSVC) ENDIF (INTEL_MKL_DIR) +# lib prefix +IF (MSVC) + SET(CMAKE_FIND_LIBRARY_PREFIXES "") + SET(CMAKE_FIND_LIBRARY_SUFFIXES ".lib" ".dll") +ELSE(MSVC) + 
SET(CMAKE_FIND_LIBRARY_PREFIXES "lib") + SET(CMAKE_FIND_LIBRARY_SUFFIXES ".so" ".a") +ENDIF (MSVC) + +MESSAGE(STATUS "Searching for MKL in ${CMAKE_LIBRARY_PATH} ...") + # Try linking multiple libs MACRO(CHECK_ALL_LIBRARIES LIBRARIES _name _list _flags) # This macro checks for the existence of the combination of libraries given by _list. @@ -258,9 +277,15 @@ ENDIF (MKL_LIBRARIES) IF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) MESSAGE(FATAL_ERROR "MKL library not found. Please specify library location") ENDIF(NOT MKL_FOUND AND MKL_FIND_REQUIRED) + + IF(NOT MKL_FIND_QUIETLY) IF(MKL_FOUND) - MESSAGE(STATUS "MKL library found") + IF (mkl64s) + MESSAGE(STATUS "MKL 64bit library found: ${mkl64s}") + ELSE(mkl64s) + MESSAGE(STATUS "MKL 32bit library found: ${mkl64s}") + ENDIF(mkl64s) ELSE(MKL_FOUND) MESSAGE(STATUS "MKL library not found") ENDIF(MKL_FOUND) diff --git a/lib/TH/generic/THBlas.c b/lib/TH/generic/THBlas.c index b04931f3..1f58060b 100644 --- a/lib/TH/generic/THBlas.c +++ b/lib/TH/generic/THBlas.c @@ -9,24 +9,35 @@ # define ffloat float #endif -TH_EXTERNC void dswap_(int *n, double *x, int *incx, double *y, int *incy); -TH_EXTERNC void sswap_(int *n, float *x, int *incx, float *y, int *incy); -TH_EXTERNC void dscal_(int *n, double *a, double *x, int *incx); -TH_EXTERNC void sscal_(int *n, float *a, float *x, int *incx); -TH_EXTERNC void dcopy_(int *n, double *x, int *incx, double *y, int *incy); -TH_EXTERNC void scopy_(int *n, float *x, int *incx, float *y, int *incy); -TH_EXTERNC void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy); -TH_EXTERNC void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy); -TH_EXTERNC double ddot_(int *n, double *x, int *incx, double *y, int *incy); -TH_EXTERNC ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy); -TH_EXTERNC void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy); -TH_EXTERNC void sgemv_(char *trans, 
int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy); -TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda); -TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda); -TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc); -TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc); +#ifdef MKL_ILP64 + // set 64 bit MKL integer type + #if (!defined(__INTEL_COMPILER)) & defined(_MSC_VER) + #define BLAS_INT __int64 + #else + #define BLAS_INT long long int + #endif +#else + #define BLAS_INT int +#endif +TH_EXTERNC void dswap_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); +TH_EXTERNC void sswap_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); +TH_EXTERNC void dscal_(BLAS_INT *n, double *a, double *x, BLAS_INT *incx); +TH_EXTERNC void sscal_(BLAS_INT *n, float *a, float *x, BLAS_INT *incx); +TH_EXTERNC void dcopy_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); +TH_EXTERNC void scopy_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); +TH_EXTERNC void daxpy_(BLAS_INT *n, double *a, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); +TH_EXTERNC void saxpy_(BLAS_INT *n, float *a, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); +TH_EXTERNC double ddot_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy); +TH_EXTERNC ffloat sdot_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy); +TH_EXTERNC void dgemv_(char *trans, BLAS_INT *m, BLAS_INT *n, double *alpha, double *a, BLAS_INT *lda, double *x, BLAS_INT *incx, double *beta, double *y, BLAS_INT *incy); +TH_EXTERNC void sgemv_(char 
*trans, BLAS_INT *m, BLAS_INT *n, float *alpha, float *a, BLAS_INT *lda, float *x, BLAS_INT *incx, float *beta, float *y, BLAS_INT *incy); +TH_EXTERNC void dger_(BLAS_INT *m, BLAS_INT *n, double *alpha, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy, double *a, BLAS_INT *lda); +TH_EXTERNC void sger_(BLAS_INT *m, BLAS_INT *n, float *alpha, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy, float *a, BLAS_INT *lda); +TH_EXTERNC void dgemm_(char *transa, char *transb, BLAS_INT *m, BLAS_INT *n, BLAS_INT *k, double *alpha, double *a, BLAS_INT *lda, double *b, BLAS_INT *ldb, double *beta, double *c, BLAS_INT *ldc); +TH_EXTERNC void sgemm_(char *transa, char *transb, BLAS_INT *m, BLAS_INT *n, BLAS_INT *k, float *alpha, float *a, BLAS_INT *lda, float *b, BLAS_INT *ldb, float *beta, float *c, BLAS_INT *ldc); + void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) { @@ -39,9 +50,9 @@ void THBlas_(swap)(long n, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_incx = (BLAS_INT)incx; + BLAS_INT i_incy = (BLAS_INT)incy; #if defined(TH_REAL_IS_DOUBLE) dswap_(&i_n, x, &i_incx, y, &i_incy); @@ -70,8 +81,8 @@ void THBlas_(scal)(long n, real a, real *x, long incx) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_incx = (BLAS_INT)incx; #if defined(TH_REAL_IS_DOUBLE) dscal_(&i_n, &a, x, &i_incx); @@ -83,13 +94,8 @@ void THBlas_(scal)(long n, real a, real *x, long incx) #endif { long i; - for(i = 0; i < n; i++) { - if (a == 0) { - x[i*incx] = 0; - } else { - x[i*incx] *= a; - } - } + for(i = 0; i < n; i++) + x[i*incx] *= a; } } @@ -104,9 +110,9 @@ void 
THBlas_(copy)(long n, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_incx = (BLAS_INT)incx; + BLAS_INT i_incy = (BLAS_INT)incy; #if defined(TH_REAL_IS_DOUBLE) dcopy_(&i_n, x, &i_incx, y, &i_incy); @@ -134,9 +140,9 @@ void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_incx = (BLAS_INT)incx; + BLAS_INT i_incy = (BLAS_INT)incy; #if defined(TH_REAL_IS_DOUBLE) daxpy_(&i_n, &a, x, &i_incx, y, &i_incy); @@ -164,9 +170,9 @@ real THBlas_(dot)(long n, real *x, long incx, real *y, long incy) #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_incx = (BLAS_INT)incx; + BLAS_INT i_incy = (BLAS_INT)incy; #if defined(TH_REAL_IS_DOUBLE) return (real) ddot_(&i_n, x, &i_incx, y, &i_incy); @@ -195,11 +201,11 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re (incx > 0) && (incx <= INT_MAX) && (incy > 0) && (incy <= INT_MAX) ) { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; + BLAS_INT i_m = (BLAS_INT)m; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_lda = (BLAS_INT)lda; + BLAS_INT i_incx = (BLAS_INT)incx; + BLAS_INT i_incy = (BLAS_INT)incy; #if defined(TH_REAL_IS_DOUBLE) dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); 
@@ -250,11 +256,11 @@ void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; + BLAS_INT i_m = (BLAS_INT)m; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_lda = (BLAS_INT)lda; + BLAS_INT i_incx = (BLAS_INT)incx; + BLAS_INT i_incy = (BLAS_INT)incy; #if defined(TH_REAL_IS_DOUBLE) dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda); @@ -309,12 +315,12 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha, #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) ) { - int i_m = (int)m; - int i_n = (int)n; - int i_k = (int)k; - int i_lda = (int)lda; - int i_ldb = (int)ldb; - int i_ldc = (int)ldc; + BLAS_INT i_m = (BLAS_INT)m; + BLAS_INT i_n = (BLAS_INT)n; + BLAS_INT i_k = (BLAS_INT)k; + BLAS_INT i_lda = (BLAS_INT)lda; + BLAS_INT i_ldb = (BLAS_INT)ldb; + BLAS_INT i_ldc = (BLAS_INT)ldc; #if defined(TH_REAL_IS_DOUBLE) dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc); diff --git a/lib/TH/generic/THLapack.c b/lib/TH/generic/THLapack.c index 148ae26c..910c19ca 100644 --- a/lib/TH/generic/THLapack.c +++ b/lib/TH/generic/THLapack.c @@ -2,43 +2,42 @@ #define TH_GENERIC_FILE "generic/THLapack.c" #else - -TH_EXTERNC void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info); -TH_EXTERNC void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info); -TH_EXTERNC void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info); 
-TH_EXTERNC void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info); -TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); -TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); -TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); -TH_EXTERNC void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); -TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); -TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); -TH_EXTERNC void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info); -TH_EXTERNC void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info); -TH_EXTERNC void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info); -TH_EXTERNC void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info); -TH_EXTERNC void dgetrs_(char *trans, int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info); -TH_EXTERNC void sgetrs_(char *trans, int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info); -TH_EXTERNC void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info); -TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int 
*info); -TH_EXTERNC void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info); -TH_EXTERNC void spotrf_(char *uplo, int *n, float *a, int *lda, int *info); -TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info); -TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info); -TH_EXTERNC void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info); -TH_EXTERNC void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info); -TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info); -TH_EXTERNC void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info); -TH_EXTERNC void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info); -TH_EXTERNC void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info); -TH_EXTERNC void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info); -TH_EXTERNC void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info); -TH_EXTERNC void spstrf_(char *uplo, int *n, float *a, int *lda, int *piv, int *rank, float *tol, float *work, int *info); -TH_EXTERNC void dpstrf_(char *uplo, int *n, double *a, int *lda, int *piv, int *rank, double *tol, double *work, int *info); +TH_EXTERNC void dgesv_(LAPACK_INT *n, LAPACK_INT *nrhs, double *a, LAPACK_INT *lda, LAPACK_INT *ipiv, double *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void sgesv_(LAPACK_INT *n, LAPACK_INT *nrhs, float *a, LAPACK_INT *lda, LAPACK_INT *ipiv, float *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void dtrtrs_(char *uplo, char *trans, char *diag, LAPACK_INT *n, LAPACK_INT *nrhs, double *a, LAPACK_INT *lda, 
double *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void strtrs_(char *uplo, char *trans, char *diag, LAPACK_INT *n, LAPACK_INT *nrhs, float *a, LAPACK_INT *lda, float *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void dgels_(char *trans, LAPACK_INT *m, LAPACK_INT *n, LAPACK_INT *nrhs, double *a, LAPACK_INT *lda, double *b, LAPACK_INT *ldb, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void sgels_(char *trans, LAPACK_INT *m, LAPACK_INT *n, LAPACK_INT *nrhs, float *a, LAPACK_INT *lda, float *b, LAPACK_INT *ldb, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dsyev_(char *jobz, char *uplo, LAPACK_INT *n, double *a, LAPACK_INT *lda, double *w, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void ssyev_(char *jobz, char *uplo, LAPACK_INT *n, float *a, LAPACK_INT *lda, float *w, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, LAPACK_INT *n, double *a, LAPACK_INT *lda, double *wr, double *wi, double* vl, LAPACK_INT *ldvl, double *vr, LAPACK_INT *ldvr, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, LAPACK_INT *n, float *a, LAPACK_INT *lda, float *wr, float *wi, float* vl, LAPACK_INT *ldvl, float *vr, LAPACK_INT *ldvr, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dgesvd_(char *jobu, char *jobvt, LAPACK_INT *m, LAPACK_INT *n, double *a, LAPACK_INT *lda, double *s, double *u, LAPACK_INT *ldu, double *vt, LAPACK_INT *ldvt, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void sgesvd_(char *jobu, char *jobvt, LAPACK_INT *m, LAPACK_INT *n, float *a, LAPACK_INT *lda, float *s, float *u, LAPACK_INT *ldu, float *vt, LAPACK_INT *ldvt, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dgetrf_(LAPACK_INT *m, LAPACK_INT *n, double *a, LAPACK_INT *lda, LAPACK_INT *ipiv, LAPACK_INT *info); +TH_EXTERNC void sgetrf_(LAPACK_INT *m, LAPACK_INT *n, float *a, 
LAPACK_INT *lda, LAPACK_INT *ipiv, LAPACK_INT *info); +TH_EXTERNC void dgetrs_(char *trans, LAPACK_INT *n, LAPACK_INT *nrhs, double *a, LAPACK_INT *lda, LAPACK_INT *ipiv, double *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void sgetrs_(char *trans, LAPACK_INT *n, LAPACK_INT *nrhs, float *a, LAPACK_INT *lda, LAPACK_INT *ipiv, float *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void dgetri_(LAPACK_INT *n, double *a, LAPACK_INT *lda, LAPACK_INT *ipiv, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void sgetri_(LAPACK_INT *n, float *a, LAPACK_INT *lda, LAPACK_INT *ipiv, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dpotrf_(char *uplo, LAPACK_INT *n, double *a, LAPACK_INT *lda, LAPACK_INT *info); +TH_EXTERNC void spotrf_(char *uplo, LAPACK_INT *n, float *a, LAPACK_INT *lda, LAPACK_INT *info); +TH_EXTERNC void dpotri_(char *uplo, LAPACK_INT *n, double *a, LAPACK_INT *lda, LAPACK_INT *info); +TH_EXTERNC void spotri_(char *uplo, LAPACK_INT *n, float *a, LAPACK_INT *lda, LAPACK_INT *info); +TH_EXTERNC void dpotrs_(char *uplo, LAPACK_INT *n, LAPACK_INT *nrhs, double *a, LAPACK_INT *lda, double *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void spotrs_(char *uplo, LAPACK_INT *n, LAPACK_INT *nrhs, float *a, LAPACK_INT *lda, float *b, LAPACK_INT *ldb, LAPACK_INT *info); +TH_EXTERNC void sgeqrf_(LAPACK_INT *m, LAPACK_INT *n, float *a, LAPACK_INT *lda, float *tau, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dgeqrf_(LAPACK_INT *m, LAPACK_INT *n, double *a, LAPACK_INT *lda, double *tau, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void sorgqr_(LAPACK_INT *m, LAPACK_INT *n, LAPACK_INT *k, float *a, LAPACK_INT *lda, float *tau, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dorgqr_(LAPACK_INT *m, LAPACK_INT *n, LAPACK_INT *k, double *a, LAPACK_INT *lda, double *tau, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void sormqr_(char *side, char 
*trans, LAPACK_INT *m, LAPACK_INT *n, LAPACK_INT *k, float *a, LAPACK_INT *lda, float *tau, float *c, LAPACK_INT *ldc, float *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void dormqr_(char *side, char *trans, LAPACK_INT *m, LAPACK_INT *n, LAPACK_INT *k, double *a, LAPACK_INT *lda, double *tau, double *c, LAPACK_INT *ldc, double *work, LAPACK_INT *lwork, LAPACK_INT *info); +TH_EXTERNC void spstrf_(char *uplo, LAPACK_INT *n, float *a, LAPACK_INT *lda, LAPACK_INT *piv, LAPACK_INT *rank, float *tol, float *work, LAPACK_INT *info); +TH_EXTERNC void dpstrf_(char *uplo, LAPACK_INT *n, double *a, LAPACK_INT *lda, LAPACK_INT *piv, LAPACK_INT *rank, double *tol, double *work, LAPACK_INT *info); /* Compute the solution to a real system of linear equations A * X = B */ -void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info) +void THLapack_(gesv)(LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, real *b, LAPACK_INT ldb, LAPACK_INT* info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -53,7 +52,7 @@ void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int } /* Solve a triangular system of the form A * X = B or A^T * X = B */ -void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a, int lda, real *b, int ldb, int* info) +void THLapack_(trtrs)(char uplo, char trans, char diag, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, real *b, LAPACK_INT ldb, LAPACK_INT* info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -69,7 +68,7 @@ void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a /* Solve overdetermined or underdetermined real linear systems involving an M-by-N matrix A, or its transpose, using a QR or LQ factorization of A */ -void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info) +void THLapack_(gels)(char trans, LAPACK_INT m, LAPACK_INT 
n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, real *b, LAPACK_INT ldb, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -84,7 +83,7 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real /* Compute all eigenvalues and, optionally, eigenvectors of a real symmetric matrix A */ -void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info) +void THLapack_(syev)(char jobz, char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, real *w, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -99,7 +98,7 @@ void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, rea /* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and, optionally, the left and/or right eigenvectors */ -void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr, real *wi, real* vl, int ldvl, real *vr, int ldvr, real *work, int lwork, int *info) +void THLapack_(geev)(char jobvl, char jobvr, LAPACK_INT n, real *a, LAPACK_INT lda, real *wr, real *wi, real* vl, LAPACK_INT ldvl, real *vr, LAPACK_INT ldvr, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -114,7 +113,7 @@ void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr, /* Compute the singular value decomposition (SVD) of a real M-by-N matrix A, optionally computing the left and/or right singular vectors */ -void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info) +void THLapack_(gesvd)(char jobu, char jobvt, LAPACK_INT m, LAPACK_INT n, real *a, LAPACK_INT lda, real *s, real *u, LAPACK_INT ldu, real *vt, LAPACK_INT ldvt, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -128,7 +127,7 @@ void 
THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, rea } /* LU decomposition */ -void THLapack_(getrf)(int m, int n, real *a, int lda, int *ipiv, int *info) +void THLapack_(getrf)(LAPACK_INT m, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -141,7 +140,7 @@ void THLapack_(getrf)(int m, int n, real *a, int lda, int *ipiv, int *info) #endif } -void THLapack_(getrs)(char trans, int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int *info) +void THLapack_(getrs)(char trans, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, real *b, LAPACK_INT ldb, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -155,7 +154,7 @@ void THLapack_(getrs)(char trans, int n, int nrhs, real *a, int lda, int *ipiv, } /* Matrix Inverse */ -void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int lwork, int* info) +void THLapack_(getri)(LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, real *work, LAPACK_INT lwork, LAPACK_INT* info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -169,7 +168,7 @@ void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int lwork, } /* Cholesky factorization */ -void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info) +void THLapack_(potrf)(char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -183,7 +182,7 @@ void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info) } /* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */ -void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info) +void THLapack_(potrs)(char uplo, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, real *b, LAPACK_INT ldb, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -197,7 +196,7 @@ void 
THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int } /* Cholesky factorization based Matrix Inverse */ -void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info) +void THLapack_(potri)(char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -211,7 +210,7 @@ void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info) } /* Cholesky factorization with complete pivoting */ -void THLapack_(pstrf)(char uplo, int n, real *a, int lda, int *piv, int *rank, real tol, real *work, int *info) +void THLapack_(pstrf)(char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *piv, LAPACK_INT *rank, real tol, real *work, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -225,7 +224,7 @@ void THLapack_(pstrf)(char uplo, int n, real *a, int lda, int *piv, int *rank, r } /* QR decomposition */ -void THLapack_(geqrf)(int m, int n, real *a, int lda, real *tau, real *work, int lwork, int *info) +void THLapack_(geqrf)(LAPACK_INT m, LAPACK_INT n, real *a, LAPACK_INT lda, real *tau, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -239,7 +238,7 @@ void THLapack_(geqrf)(int m, int n, real *a, int lda, real *tau, real *work, int } /* Build Q from output of geqrf */ -void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *work, int lwork, int *info) +void THLapack_(orgqr)(LAPACK_INT m, LAPACK_INT n, LAPACK_INT k, real *a, LAPACK_INT lda, real *tau, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) @@ -253,7 +252,7 @@ void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *wo } /* Multiply Q with a matrix using the output of geqrf */ -void THLapack_(ormqr)(char side, char trans, int m, int n, int k, real *a, int lda, real *tau, real *c, int ldc, real *work, int lwork, int *info) +void THLapack_(ormqr)(char side, char 
trans, LAPACK_INT m, LAPACK_INT n, LAPACK_INT k, real *a, LAPACK_INT lda, real *tau, real *c, LAPACK_INT ldc, real *work, LAPACK_INT lwork, LAPACK_INT *info) { #ifdef USE_LAPACK #if defined(TH_REAL_IS_DOUBLE) diff --git a/lib/TH/generic/THLapack.h b/lib/TH/generic/THLapack.h index b464dd2d..1aafd658 100644 --- a/lib/TH/generic/THLapack.h +++ b/lib/TH/generic/THLapack.h @@ -3,38 +3,38 @@ #else /* AX=B */ -TH_API void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info); +TH_API void THLapack_(gesv)(LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, real *b, LAPACK_INT ldb, LAPACK_INT* info); /* Solve a triangular system of the form A * X = B or A^T * X = B */ -TH_API void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, real *a, int lda, real *b, int ldb, int* info); +TH_API void THLapack_(trtrs)(char uplo, char trans, char diag, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, real *b, LAPACK_INT ldb, LAPACK_INT* info); /* ||AX-B|| */ -TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real *b, int ldb, real *work, int lwork, int *info); +TH_API void THLapack_(gels)(char trans, LAPACK_INT m, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, real *b, LAPACK_INT ldb, real *work, LAPACK_INT lwork, LAPACK_INT *info); /* Eigenvals */ -TH_API void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, real *work, int lwork, int *info); +TH_API void THLapack_(syev)(char jobz, char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, real *w, real *work, LAPACK_INT lwork, LAPACK_INT *info); /* Non-sym eigenvals */ -TH_API void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr, real *wi, real* vl, int ldvl, real *vr, int ldvr, real *work, int lwork, int *info); +TH_API void THLapack_(geev)(char jobvl, char jobvr, LAPACK_INT n, real *a, LAPACK_INT lda, real *wr, real *wi, real* vl, LAPACK_INT ldvl, real *vr, LAPACK_INT 
ldvr, real *work, LAPACK_INT lwork, LAPACK_INT *info); /* svd */ -TH_API void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, real *s, real *u, int ldu, real *vt, int ldvt, real *work, int lwork, int *info); +TH_API void THLapack_(gesvd)(char jobu, char jobvt, LAPACK_INT m, LAPACK_INT n, real *a, LAPACK_INT lda, real *s, real *u, LAPACK_INT ldu, real *vt, LAPACK_INT ldvt, real *work, LAPACK_INT lwork, LAPACK_INT *info); /* LU decomposition */ -TH_API void THLapack_(getrf)(int m, int n, real *a, int lda, int *ipiv, int *info); -TH_API void THLapack_(getrs)(char trans, int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int *info); +TH_API void THLapack_(getrf)(LAPACK_INT m, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, LAPACK_INT *info); +TH_API void THLapack_(getrs)(char trans, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, real *b, LAPACK_INT ldb, LAPACK_INT *info); /* Matrix Inverse */ -TH_API void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int lwork, int* info); +TH_API void THLapack_(getri)(LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *ipiv, real *work, LAPACK_INT lwork, LAPACK_INT* info); /* Positive Definite matrices */ /* Cholesky factorization */ -void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info); +void THLapack_(potrf)(char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *info); /* Matrix inverse based on Cholesky factorization */ -void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info); +void THLapack_(potri)(char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *info); /* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */ -void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info); +void THLapack_(potrs)(char uplo, LAPACK_INT n, LAPACK_INT nrhs, real *a, LAPACK_INT lda, real *b, LAPACK_INT ldb, LAPACK_INT *info); /* Cholesky 
factorization with complete pivoting. */ -void THLapack_(pstrf)(char uplo, int n, real *a, int lda, int *piv, int *rank, real tol, real *work, int *info); +void THLapack_(pstrf)(char uplo, LAPACK_INT n, real *a, LAPACK_INT lda, LAPACK_INT *piv, LAPACK_INT *rank, real tol, real *work, LAPACK_INT *info); /* QR decomposition */ -void THLapack_(geqrf)(int m, int n, real *a, int lda, real *tau, real *work, int lwork, int *info); +void THLapack_(geqrf)(LAPACK_INT m, LAPACK_INT n, real *a, LAPACK_INT lda, real *tau, real *work, LAPACK_INT lwork, LAPACK_INT *info); /* Build Q from output of geqrf */ -void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *work, int lwork, int *info); +void THLapack_(orgqr)(LAPACK_INT m, LAPACK_INT n, LAPACK_INT k, real *a, LAPACK_INT lda, real *tau, real *work, LAPACK_INT lwork, LAPACK_INT *info); /* Multiply Q with a matrix from output of geqrf */ -void THLapack_(ormqr)(char side, char trans, int m, int n, int k, real *a, int lda, real *tau, real *c, int ldc, real *work, int lwork, int *info); +void THLapack_(ormqr)(char side, char trans, LAPACK_INT m, LAPACK_INT n, LAPACK_INT k, real *a, LAPACK_INT lda, real *tau, real *c, LAPACK_INT ldc, real *work, LAPACK_INT lwork, LAPACK_INT *info); #endif diff --git a/lib/TH/generic/THTensorLapack.c b/lib/TH/generic/THTensorLapack.c index d0196c98..17c9e5c6 100644 --- a/lib/TH/generic/THTensorLapack.c +++ b/lib/TH/generic/THTensorLapack.c @@ -121,35 +121,35 @@ void THTensor_(gesv)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) free_b = 1; } - int n, nrhs, lda, ldb, info; - THIntTensor *ipiv; + LAPACK_INT n, nrhs, lda, ldb, info; + LAPACK_INT *ipiv; THTensor *ra__; // working version of A matrix to be passed into lapack GELS THTensor *rb__; // working version of B matrix to be passed into lapack GELS ra__ = THTensor_(cloneColumnMajor)(ra_, a); rb__ = THTensor_(cloneColumnMajor)(rb_, b); - n = (int)ra__->size[0]; - nrhs = (int)rb__->size[1]; + n = 
(LAPACK_INT)ra__->size[0]; + nrhs = (LAPACK_INT)rb__->size[1]; lda = n; ldb = n; - ipiv = THIntTensor_newWithSize1d((long)n); + ipiv = (LAPACK_INT*)THAlloc(n * sizeof(LAPACK_INT)); THLapack_(gesv)(n, nrhs, - THTensor_(data)(ra__), lda, THIntTensor_data(ipiv), + THTensor_(data)(ra__), lda, ipiv, THTensor_(data)(rb__), ldb, &info); THLapackCheckWithCleanup("Lapack Error in %s : U(%d,%d) is zero, singular U.", THCleanup( THTensor_(free)(ra__); THTensor_(free)(rb__); - THIntTensor_free(ipiv); + THFree(ipiv); if (free_b) THTensor_(free)(b);), "gesv", info, info); THTensor_(freeCopyTo)(ra__, ra_); THTensor_(freeCopyTo)(rb__, rb_); - THIntTensor_free(ipiv); + THFree(ipiv); if (free_b) THTensor_(free)(b); } @@ -174,15 +174,15 @@ void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a, free_b = 1; } - int n, nrhs, lda, ldb, info; + LAPACK_INT n, nrhs, lda, ldb, info; THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS THTensor *rb__; // working version of B matrix to be passed into lapack TRTRS ra__ = THTensor_(cloneColumnMajor)(ra_, a); rb__ = THTensor_(cloneColumnMajor)(rb_, b); - n = (int)ra__->size[0]; - nrhs = (int)rb__->size[1]; + n = (LAPACK_INT)ra__->size[0]; + nrhs = (LAPACK_INT)rb__->size[1]; lda = n; ldb = n; @@ -222,7 +222,7 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) free_b = 1; } - int m, n, nrhs, lda, ldb, info, lwork; + LAPACK_INT m, n, nrhs, lda, ldb, info, lwork; THTensor *work = NULL; real wkopt = 0; @@ -231,8 +231,8 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) ra__ = THTensor_(cloneColumnMajor)(ra_, a); - m = ra__->size[0]; - n = ra__->size[1]; + m = (LAPACK_INT)ra__->size[0]; + n = (LAPACK_INT)ra__->size[1]; lda = m; ldb = (m > n) ? 
m : n; @@ -272,7 +272,7 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr) { - int n, lda, lwork, info, ldvr; + LAPACK_INT n, lda, lwork, info, ldvr; THTensor *work, *wi, *wr, *a; real wkopt; real *rv_data; @@ -287,7 +287,7 @@ void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *job /* we want to definitely clone a_ for geev*/ a = THTensor_(cloneColumnMajor)(NULL, a_); - n = a->size[0]; + n = (LAPACK_INT)a->size[0]; lda = n; wi = THTensor_(newWithSize1d)(n); @@ -310,7 +310,7 @@ void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *job THLapack_(geev)('N', jobvr[0], n, THTensor_(data)(a), lda, THTensor_(data)(wr), THTensor_(data)(wi), NULL, 1, rv_data, ldvr, &wkopt, -1, &info); - lwork = (int)wkopt; + lwork = (LAPACK_INT)wkopt; work = THTensor_(newWithSize1d)(lwork); THLapack_(geev)('N', jobvr[0], n, THTensor_(data)(a), lda, THTensor_(data)(wr), THTensor_(data)(wi), @@ -354,7 +354,7 @@ void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional"); THArgCheck(a->size[0] == a->size[1], 1,"A should be square"); - int n, lda, lwork, info; + LAPACK_INT n, lda, lwork, info; THTensor *work; real wkopt; @@ -363,7 +363,7 @@ void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz rv__ = THTensor_(cloneColumnMajor)(rv_, a); - n = rv__->size[0]; + n = (LAPACK_INT)rv__->size[0]; lda = n; THTensor_(resize1d)(re_,n); @@ -372,7 +372,7 @@ void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a, const char *jobz /* get optimal workspace size */ THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda, THTensor_(data)(re_), &wkopt, -1, &info); - lwork = (int)wkopt; + lwork = (LAPACK_INT)wkopt; work = THTensor_(newWithSize1d)(lwork); THLapack_(syev)(jobz[0], uplo[0], n, THTensor_(data)(rv__), lda, 
THTensor_(data)(re_), THTensor_(data)(work), lwork, &info); @@ -400,7 +400,7 @@ void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra if (a == NULL) a = ra_; THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional"); - int k,m, n, lda, ldu, ldvt, lwork, info; + LAPACK_INT k,m, n, lda, ldu, ldvt, lwork, info; THTensor *work; THTensor *rvf_ = THTensor_(new)(); real wkopt; @@ -412,8 +412,8 @@ void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra ra__ = THTensor_(cloneColumnMajor)(ra_, a); - m = ra__->size[0]; - n = ra__->size[1]; + m = (LAPACK_INT)ra__->size[0]; + n = (LAPACK_INT)ra__->size[1]; k = (m < n ? m : n); lda = m; @@ -441,7 +441,7 @@ void THTensor_(gesvd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra ldu, THTensor_(data)(rv__), ldvt, &wkopt, -1, &info); - lwork = (int)wkopt; + lwork = (LAPACK_INT)wkopt; work = THTensor_(newWithSize1d)(lwork); THLapack_(gesvd)(jobu[0],jobu[0], m,n,THTensor_(data)(ra__),lda, @@ -483,42 +483,42 @@ void THTensor_(getri)(THTensor *ra_, THTensor *a) THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional"); THArgCheck(a->size[0] == a->size[1], 1, "A should be square"); - int m, n, lda, info, lwork; + LAPACK_INT m, n, lda, info, lwork; real wkopt; - THIntTensor *ipiv; + LAPACK_INT *ipiv; THTensor *work; THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); - m = ra__->size[0]; - n = ra__->size[1]; + m = (LAPACK_INT)ra__->size[0]; + n = (LAPACK_INT)ra__->size[1]; lda = m; - ipiv = THIntTensor_newWithSize1d((long)m); + ipiv = (LAPACK_INT*) THAlloc(m * sizeof(LAPACK_INT)); /* Run LU */ - THLapack_(getrf)(n, n, THTensor_(data)(ra__), lda, THIntTensor_data(ipiv), &info); + THLapack_(getrf)(n, n, THTensor_(data)(ra__), lda, ipiv, &info); THLapackCheckWithCleanup("Lapack Error %s : U(%d,%d) is 0, U is singular", THCleanup( THTensor_(free)(ra__); - THIntTensor_free(ipiv);), + THFree(ipiv);), "getrf", info, info); /* Run inverse */ - 
THLapack_(getri)(n, THTensor_(data)(ra__), lda, THIntTensor_data(ipiv), &wkopt, -1, &info); - lwork = (int)wkopt; + THLapack_(getri)(n, THTensor_(data)(ra__), lda, ipiv, &wkopt, -1, &info); + lwork = (LAPACK_INT)wkopt; work = THTensor_(newWithSize1d)(lwork); - THLapack_(getri)(n, THTensor_(data)(ra__), lda, THIntTensor_data(ipiv), THTensor_(data)(work), lwork, &info); + THLapack_(getri)(n, THTensor_(data)(ra__), lda, ipiv, THTensor_(data)(work), lwork, &info); THLapackCheckWithCleanup("Lapack Error %s : U(%d,%d) is 0, U is singular", THCleanup( THTensor_(free)(ra__); THTensor_(free)(work); - THIntTensor_free(ipiv);), + THFree(ipiv);), "getri", info, info); THTensor_(freeCopyTo)(ra__, ra_); THTensor_(free)(work); - THIntTensor_free(ipiv); + THFree(ipiv); } void THTensor_(clearUpLoTriangle)(THTensor *a, const char *uplo) @@ -593,12 +593,12 @@ void THTensor_(potrf)(THTensor *ra_, THTensor *a, const char *uplo) THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional"); THArgCheck(a->size[0] == a->size[1], 1, "A should be square"); - int n, lda, info; + LAPACK_INT n, lda, info; THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); - n = ra__->size[0]; + n = (LAPACK_INT)ra__->size[0]; lda = n; /* Run Factorization */ @@ -631,15 +631,15 @@ void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo) free_b = 1; } - int n, nrhs, lda, ldb, info; + LAPACK_INT n, nrhs, lda, ldb, info; THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS THTensor *rb__; // working version of B matrix to be passed into lapack TRTRS ra__ = THTensor_(cloneColumnMajor)(NULL, a); rb__ = THTensor_(cloneColumnMajor)(rb_, b); - n = (int)ra__->size[0]; - nrhs = (int)rb__->size[1]; + n = (LAPACK_INT)ra__->size[0]; + nrhs = (LAPACK_INT)rb__->size[1]; lda = n; ldb = n; @@ -665,12 +665,12 @@ void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo) THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional"); THArgCheck(a->size[0] 
== a->size[1], 1, "A should be square"); - int n, lda, info; + LAPACK_INT n, lda, info; THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); - n = ra__->size[0]; + n = (LAPACK_INT)ra__->size[0]; lda = n; /* Run inverse */ @@ -703,32 +703,58 @@ void THTensor_(pstrf)(THTensor *ra_, THIntTensor *rpiv_, THTensor *a, const char THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional"); THArgCheck(a->size[0] == a->size[1], 1, "A should be square"); - int n = a->size[0]; + LAPACK_INT n = a->size[0]; THTensor *ra__ = THTensor_(cloneColumnMajor)(ra_, a); THIntTensor_resize1d(rpiv_, n); + LAPACK_INT *t_rp; + if (sizeof(LAPACK_INT) == sizeof(int)) + t_rp = (LAPACK_INT*)THIntTensor_data(rpiv_); + else + t_rp = (LAPACK_INT*)THAlloc(n * sizeof(LAPACK_INT)); + // Allocate working tensor THTensor *work = THTensor_(newWithSize1d)(2 * n); // Run Cholesky factorization - int lda = n; - int rank, info; + LAPACK_INT lda = n; + LAPACK_INT rank, info; THLapack_(pstrf)(uplo[0], n, THTensor_(data)(ra__), lda, - THIntTensor_data(rpiv_), &rank, tol, + t_rp, &rank, tol, THTensor_(data)(work), &info); - THLapackCheckWithCleanup("Lapack Error %s : matrix is rank deficient or not positive semidefinite", - THCleanup( - THTensor_(free)(ra__); - THTensor_(free)(work);), - "pstrf", info,""); + if (sizeof(LAPACK_INT) == sizeof(int)) + { + THLapackCheckWithCleanup("Lapack Error %s : matrix is rank deficient or not positive semidefinite", + THCleanup( + THTensor_(free)(ra__); + THTensor_(free)(work);), + "pstrf", info,""); + + } + else + { + THLapackCheckWithCleanup("Lapack Error %s : matrix is rank deficient or not positive semidefinite", + THCleanup( + THFree(t_rp); + THTensor_(free)(ra__); + THTensor_(free)(work);), + "pstrf", info,""); + + // copy back to int tensor + int *pdst = THIntTensor_data(rpiv_); + LAPACK_INT *psrc = t_rp; + for (int i = 0; i < n; i++) *pdst++ = (int)*psrc++; + THFree(t_rp); + } THTensor_(clearUpLoTriangle)(ra__, uplo); THTensor_(freeCopyTo)(ra__, ra_); 
THTensor_(free)(work); + } /* @@ -793,21 +819,21 @@ void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a) /* Prepare the input for LAPACK, making a copy if necessary. */ ra__ = THTensor_(cloneColumnMajor)(ra_, a); - int m = ra__->size[0]; - int n = ra__->size[1]; + LAPACK_INT m = (LAPACK_INT)ra__->size[0]; + LAPACK_INT n = (LAPACK_INT)ra__->size[1]; int k = (m < n ? m : n); int lda = m; THTensor_(resize1d)(rtau_, k); /* Dry-run to query the suggested size of the workspace. */ - int info = 0; + LAPACK_INT info = 0; real wkopt = 0; THLapack_(geqrf)(m, n, THTensor_(data)(ra__), lda, THTensor_(data)(rtau_), &wkopt, -1, &info); /* Allocate the workspace and call LAPACK to do the real work. */ - int lwork = (int)wkopt; + int lwork = (LAPACK_INT)wkopt; THTensor *work = THTensor_(newWithSize1d)(lwork); THLapack_(geqrf)(m, n, THTensor_(data)(ra__), lda, THTensor_(data)(rtau_), @@ -847,20 +873,20 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau) THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); - int m = ra__->size[0]; - int n = ra__->size[1]; - int k = tau->size[0]; + LAPACK_INT m = (LAPACK_INT)ra__->size[0]; + LAPACK_INT n = (LAPACK_INT)ra__->size[1]; + LAPACK_INT k = (LAPACK_INT)tau->size[0]; int lda = m; /* Dry-run to query the suggested size of the workspace. */ - int info = 0; + LAPACK_INT info = 0; real wkopt = 0; THLapack_(orgqr)(m, k, k, THTensor_(data)(ra__), lda, THTensor_(data)(tau), &wkopt, -1, &info); /* Allocate the workspace and call LAPACK to do the real work. 
*/ - int lwork = (int)wkopt; + int lwork = (LAPACK_INT)wkopt; THTensor *work = THTensor_(newWithSize1d)(lwork); THLapack_(orgqr)(m, k, k, THTensor_(data)(ra__), lda, THTensor_(data)(tau), @@ -901,10 +927,10 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, c); - int m = c->size[0]; - int n = c->size[1]; - int k = tau->size[0]; - int lda; + LAPACK_INT m = (LAPACK_INT)c->size[0]; + LAPACK_INT n = (LAPACK_INT)c->size[1]; + LAPACK_INT k = (LAPACK_INT)tau->size[0]; + LAPACK_INT lda; if (*side == 'L') { lda = m; @@ -913,17 +939,17 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co { lda = n; } - int ldc = m; + LAPACK_INT ldc = m; /* Dry-run to query the suggested size of the workspace. */ - int info = 0; + LAPACK_INT info = 0; real wkopt = 0; THLapack_(ormqr)(side[0], trans[0], m, n, k, THTensor_(data)(a), lda, THTensor_(data)(tau), THTensor_(data)(ra__), ldc, &wkopt, -1, &info); /* Allocate the workspace and call LAPACK to do the real work. 
*/ - int lwork = (int)wkopt; + int lwork = (LAPACK_INT)wkopt; THTensor *work = THTensor_(newWithSize1d)(lwork); THLapack_(ormqr)(side[0], trans[0], m, n, k, THTensor_(data)(a), lda, THTensor_(data)(tau), THTensor_(data)(ra__), ldc, @@ -947,14 +973,14 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf THTensor_(copy)(ra_, a); } - int m = a->size[1]; - int n = a->size[2]; + LAPACK_INT m = (LAPACK_INT)a->size[1]; + LAPACK_INT n = (LAPACK_INT)a->size[2]; if (m != n) { THError("btrifact is only implemented for square matrices"); } long num_batches = THTensor_(size)(a, 0); THTensor *ra__; - int lda; + LAPACK_INT lda; if (ra_->stride[1] == 1) { // column ordered, what BLAS wants @@ -973,11 +999,21 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf THTensor *rai = THTensor_(new)(); THIntTensor *rpivoti = THIntTensor_new(); - int info = 0; - int *info_ptr = &info; + LAPACK_INT *t_rp; + if (sizeof(LAPACK_INT) != sizeof(int)) + t_rp = (LAPACK_INT*)THAlloc(n * sizeof(LAPACK_INT)); + + LAPACK_INT info = 0; + LAPACK_INT *info_ptr = &info, *t_inf; if (rinfo_) { THIntTensor_resize1d(rinfo_, num_batches); - info_ptr = THIntTensor_data(rinfo_); + if (sizeof(LAPACK_INT) != sizeof(int)) + { + t_inf = (LAPACK_INT*)THAlloc(num_batches * sizeof(LAPACK_INT)); + info_ptr = t_inf; + } + else + info_ptr = (LAPACK_INT*)THIntTensor_data(rinfo_); } THIntTensor_resize2d(rpivots_, num_batches, n); @@ -988,8 +1024,19 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf THTensor_(select)(rai, ra__, 0, batch); THIntTensor_select(rpivoti, rpivots_, 0, batch); + if (sizeof(LAPACK_INT) == sizeof(int)) + t_rp = (LAPACK_INT*)THIntTensor_data(rpivoti); + THLapack_(getrf)(n, n, THTensor_(data)(rai), lda, - THIntTensor_data(rpivoti), info_ptr); + t_rp, info_ptr); + + if (sizeof(LAPACK_INT) != sizeof(int)) + { + int *pdst = THIntTensor_data(rpivoti); + LAPACK_INT *psrc = t_rp; + for (int i = 0; i < n; i++) 
*pdst++ = (int)*psrc++; + } + if (rinfo_) { info_ptr++; } else if (info != 0) { @@ -1001,6 +1048,17 @@ void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinf THTensor_(free)(rai); THIntTensor_free(rpivoti); + if (sizeof(LAPACK_INT) != sizeof(int)) + { + if (rinfo_) { + int *pdst = THIntTensor_data(rinfo_); + LAPACK_INT *psrc = t_inf; + for (int i = 0; i < n; i++) *pdst++ = (int)*psrc++; + THFree(t_inf); + } + THFree(t_rp); + } + if (ra__ != ra_) { THTensor_(freeCopyTo)(ra__, ra_); } @@ -1029,10 +1087,10 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor } long num_batches = atf->size[0]; - long n = atf->size[1]; - int nrhs = rb_->nDimension > 2 ? rb_->size[2] : 1; + LAPACK_INT n = (LAPACK_INT)atf->size[1]; + LAPACK_INT nrhs = (LAPACK_INT)(rb_->nDimension > 2 ? rb_->size[2] : 1); - int lda, ldb; + LAPACK_INT lda, ldb; THTensor *atf_; THTensor *rb__; @@ -1084,19 +1142,34 @@ void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor THError("Error: rpivots_ is not contiguous."); } + LAPACK_INT *t_rp; + if (sizeof(LAPACK_INT) != sizeof(int)) + t_rp = (LAPACK_INT*)THAlloc(n * sizeof(LAPACK_INT)); + for (long batch = 0; batch < num_batches; ++batch) { THTensor_(select)(ai, atf_, 0, batch); THTensor_(select)(rbi, rb__, 0, batch); THIntTensor_select(pivoti, pivots, 0, batch); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) - int info; + + if (sizeof(LAPACK_INT) == sizeof(int)) + t_rp = (LAPACK_INT*)THIntTensor_data(pivoti); + + LAPACK_INT info; THLapack_(getrs)('N', n, nrhs, THTensor_(data)(ai), lda, - THIntTensor_data(pivoti), THTensor_(data)(rbi), + t_rp, THTensor_(data)(rbi), ldb, &info); if (info != 0) { THError("Error: Nonzero info."); } + if (sizeof(LAPACK_INT) != sizeof(int)) + { + int *pdst = THIntTensor_data(pivoti); + LAPACK_INT *psrc = t_rp; + for (int i = 0; i < n; i++) *pdst++ = (int)*psrc++; + } + #else THError("Unimplemented"); #endif @@ -1106,6 +1179,9 @@ void 
THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor THTensor_(free)(rbi); THIntTensor_free(pivoti); + if (sizeof(LAPACK_INT) != sizeof(int)) + THFree(t_rp); + if (atf_ != atf) { THTensor_(free)(atf_); } diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 29894e20..68e208ce 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -638,11 +638,11 @@ void THTensor_(div)(THTensor *r_, THTensor *t, real value) void THTensor_(lshift)(THTensor *r_, THTensor *t, real value) { #if defined(TH_REAL_IS_FLOAT) - return THTensor_(mul)(r_, t, powf(2, value)); + THTensor_(mul)(r_, t, powf(2, value)); #elif defined(TH_REAL_IS_DOUBLE) - return THTensor_(mul)(r_, t, pow(2, value)); + THTensor_(mul)(r_, t, pow(2, value)); #elif defined(TH_REAL_IS_HALF) - return THError("lshift is not supported for torch.HalfTensor"); + THError("lshift is not supported for torch.HalfTensor"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -673,11 +673,11 @@ void THTensor_(lshift)(THTensor *r_, THTensor *t, real value) void THTensor_(rshift)(THTensor *r_, THTensor *t, real value) { #if defined(TH_REAL_IS_FLOAT) - return THTensor_(div)(r_, t, powf(2, value)); + THTensor_(div)(r_, t, powf(2, value)); #elif defined(TH_REAL_IS_DOUBLE) - return THTensor_(div)(r_, t, pow(2, value)); + THTensor_(div)(r_, t, pow(2, value)); #elif defined(TH_REAL_IS_HALF) - return THError("rshift is not supported for torch.HalfTensor"); + THError("rshift is not supported for torch.HalfTensor"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -764,7 +764,7 @@ void THTensor_(remainder)(THTensor *r_, THTensor *t, real value) void THTensor_(bitand)(THTensor *r_, THTensor *t, real value) { #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) - return THError("bitand is only supported for integer type tensors"); + THError("bitand is only supported for integer type tensors"); 
#else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -787,7 +787,7 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, real value) void THTensor_(bitor)(THTensor *r_, THTensor *t, real value) { #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) - return THError("bitor is only supported for integer type tensors"); + THError("bitor is only supported for integer type tensors"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -810,7 +810,7 @@ void THTensor_(bitor)(THTensor *r_, THTensor *t, real value) void THTensor_(bitxor)(THTensor *r_, THTensor *t, real value) { #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) - return THError("bitxor is only supported for integer type tensors"); + THError("bitxor is only supported for integer type tensors"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -1045,7 +1045,7 @@ void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src) void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src) { #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) - return THError("cbitand is only supported for integer type tensors"); + THError("cbitand is only supported for integer type tensors"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -1070,7 +1070,7 @@ void THTensor_(cbitand)(THTensor *r_, THTensor *t, THTensor *src) void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src) { #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) - return THError("cbitor is only supported for integer type tensors"); + THError("cbitor is only supported for integer type tensors"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) && @@ -1095,7 +1095,7 @@ void THTensor_(cbitor)(THTensor *r_, THTensor *t, THTensor *src) void THTensor_(cbitxor)(THTensor *r_, THTensor *t, THTensor *src) { #if 
defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) - return THError("cbitxor is only supported for integer type tensors"); + THError("cbitxor is only supported for integer type tensors"); #else THTensor_(resizeAs)(r_, t); if (THTensor_(isContiguous)(r_) &&