Skip to content

Commit

Permalink
Make it possible to change index base in TH
Browse files Browse the repository at this point in the history
Not every language has 1-based indexing...
  • Loading branch information
apaszke committed Jul 20, 2016
1 parent 7c740d5 commit be269db
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 39 deletions.
4 changes: 4 additions & 0 deletions lib/TH/THGeneral.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
# define M_PI 3.14159265358979323846
#endif

#ifndef TH_INDEX_BASE
#define TH_INDEX_BASE 1
#endif

TH_API double THLog1p(const double x);
TH_API void _THError(const char *file, const int line, const char *fmt, ...);
TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...);
Expand Down
80 changes: 41 additions & 39 deletions lib/TH/generic/THTensorMath.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
real *tensor_data, *src_data;

THArgCheck(index->nDimension == 1, 3, "Index is supposed to be a vector");
THArgCheck(dim < src->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim+1);
THArgCheck(dim < src->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
THArgCheck(src->nDimension > 0,2,"Source tensor is empty");

numel = THLongTensor_nElement(index);
Expand All @@ -149,9 +149,9 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
long rowsize = THTensor_(nElement)(src) / src->size[0];

// check that the indices are within range
long max = src->size[0];
long max = src->size[0] - 1 + TH_INDEX_BASE;
for (i=0; i<numel; i++) {
if (index_data[i] < 1 || index_data[i] > max) {
if (index_data[i] < TH_INDEX_BASE || index_data[i] > max) {
THLongTensor_free(index);
THError("index out of range");
}
Expand All @@ -160,17 +160,17 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
if (src->nDimension == 1) {
#pragma omp parallel for if(numel > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<numel; i++)
tensor_data[i] = src_data[index_data[i]-1];
tensor_data[i] = src_data[index_data[i] - TH_INDEX_BASE];
} else {
#pragma omp parallel for if(numel*rowsize > TH_OMP_OVERHEAD_THRESHOLD) private(i)
for (i=0; i<numel; i++)
memcpy(tensor_data + i*rowsize, src_data + (index_data[i]-1)*rowsize, rowsize*sizeof(real));
memcpy(tensor_data + i*rowsize, src_data + (index_data[i] - TH_INDEX_BASE)*rowsize, rowsize*sizeof(real));
}
}
else if (src->nDimension == 1)
{
for (i=0; i<numel; i++)
THTensor_(set1d)(tensor,i,THTensor_(get1d)(src,index_data[i]-1));
THTensor_(set1d)(tensor,i,THTensor_(get1d)(src,index_data[i] - TH_INDEX_BASE));
}
else
{
Expand All @@ -179,7 +179,7 @@ void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTens
tSlice = THTensor_(new)();
sSlice = THTensor_(new)();
THTensor_(select)(tSlice, tensor, dim, i);
THTensor_(select)(sSlice, src, dim, index_data[i]-1);
THTensor_(select)(sSlice, src, dim, index_data[i] - TH_INDEX_BASE);
THTensor_(copy)(tSlice, sSlice);
THTensor_(free)(tSlice);
THTensor_(free)(sSlice);
Expand All @@ -197,7 +197,7 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens

numel = THLongTensor_nElement(index);
THArgCheck(index->nDimension == 1, 3, "Index is supposed to be a vector");
THArgCheck(dim < src->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim+1);
THArgCheck(dim < src->nDimension, 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
THArgCheck(numel == src->size[dim],4,"Number of indices should be equal to source:size(dim)");

index = THLongTensor_newContiguous(index);
Expand All @@ -210,7 +210,7 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens

for (i=0; i<numel; i++)
{
THTensor_(select)(tSlice, tensor, dim, index_data[i]-1);
THTensor_(select)(tSlice, tensor, dim, index_data[i] - TH_INDEX_BASE);
THTensor_(select)(sSlice, src, dim, i);
THTensor_(copy)(tSlice, sSlice);
}
Expand All @@ -222,7 +222,7 @@ void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTens
{
for (i=0; i<numel; i++)
{
THTensor_(set1d)(tensor,index_data[i]-1,THTensor_(get1d)(src,i));
THTensor_(set1d)(tensor, index_data[i] - TH_INDEX_BASE, THTensor_(get1d)(src,i));
}
}
THLongTensor_free(index);
Expand All @@ -236,20 +236,20 @@ void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTenso

numel = THLongTensor_nElement(index);
THArgCheck(index->nDimension == 1, 3, "Index is supposed to be a vector");
THArgCheck(dim < src->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim+1);
THArgCheck(dim < src->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
THArgCheck(numel == src->size[dim],4,"Number of indices should be equal to source:size(dim)");

index = THLongTensor_newContiguous(index);
index_data = THLongTensor_data(index);

if (tensor->nDimension > 1 )
if (tensor->nDimension > 1)
{
tSlice = THTensor_(new)();
sSlice = THTensor_(new)();

for (i=0; i<numel; i++)
{
THTensor_(select)(tSlice, tensor, dim, index_data[i]-1);
THTensor_(select)(tSlice, tensor, dim, index_data[i] - TH_INDEX_BASE);
THTensor_(select)(sSlice, src, dim, i);
THTensor_(cadd)(tSlice, tSlice, 1.0, sSlice);
}
Expand All @@ -261,7 +261,9 @@ void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTenso
{
for (i=0; i<numel; i++)
{
THTensor_(set1d)(tensor,index_data[i]-1,THTensor_(get1d)(src,i) + THTensor_(get1d)(tensor,index_data[i]-1));
THTensor_(set1d)(tensor,
index_data[i] - TH_INDEX_BASE,
THTensor_(get1d)(src,i) + THTensor_(get1d)(tensor,index_data[i] - TH_INDEX_BASE));
}
}
THLongTensor_free(index);
Expand All @@ -275,23 +277,23 @@ void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, real v

numel = THLongTensor_nElement(index);
THArgCheck(index->nDimension == 1, 3, "Index is supposed to be a vector");
THArgCheck(dim < tensor->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim+1);
THArgCheck(dim < tensor->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);

index = THLongTensor_newContiguous(index);
index_data = THLongTensor_data(index);

for (i=0; i<numel; i++)
{
if (tensor->nDimension > 1 )
if (tensor->nDimension > 1)
{
tSlice = THTensor_(new)();
THTensor_(select)(tSlice, tensor,dim,index_data[i]-1);
THTensor_(select)(tSlice, tensor,dim,index_data[i] - TH_INDEX_BASE);
THTensor_(fill)(tSlice, val);
THTensor_(free)(tSlice);
}
else
{
THTensor_(set1d)(tensor,index_data[i]-1,val);
THTensor_(set1d)(tensor, index_data[i] - TH_INDEX_BASE, val);
}
}
THLongTensor_free(index);
Expand All @@ -313,12 +315,12 @@ void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *i
for (i = 0; i < elems_per_row; ++i)
{
idx = *(index_data + i*index_stride);
if (idx < 1 || idx > src_size)
if (idx < TH_INDEX_BASE || idx >= src_size + TH_INDEX_BASE)
{
THFree(TH_TENSOR_DIM_APPLY_counter);
THError("Invalid index in gather");
}
*(tensor_data + i*tensor_stride) = src_data[(idx - 1) * src_stride];
*(tensor_data + i*tensor_stride) = src_data[(idx - TH_INDEX_BASE) * src_stride];
})
}

Expand All @@ -338,12 +340,12 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor
for (i = 0; i < elems_per_row; ++i)
{
idx = *(index_data + i*index_stride);
if (idx < 1 || idx > tensor_size)
if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE)
{
THFree(TH_TENSOR_DIM_APPLY_counter);
THError("Invalid index in scatter");
}
tensor_data[(idx - 1) * tensor_stride] = *(src_data + i*src_stride);
tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] = *(src_data + i*src_stride);
})
}

Expand All @@ -361,12 +363,12 @@ void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real
for (i = 0; i < elems_per_row; ++i)
{
idx = *(index_data + i*index_stride);
if (idx < 1 || idx > tensor_size)
if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE)
{
THFree(TH_TENSOR_DIM_APPLY_counter);
THError("Invalid index in scatter");
}
tensor_data[(idx - 1) * tensor_stride] = val;
tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] = val;
})
}

Expand Down Expand Up @@ -1061,7 +1063,7 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
long i;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand Down Expand Up @@ -1097,7 +1099,7 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
long i;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand Down Expand Up @@ -1130,7 +1132,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension)
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand All @@ -1150,7 +1152,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension)
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand All @@ -1169,7 +1171,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension)
void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension)
{
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);

THTensor_(resizeAs)(r_, t);

Expand All @@ -1186,7 +1188,7 @@ void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension)
void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension)
{
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);

THTensor_(resizeAs)(r_, t);

Expand Down Expand Up @@ -1273,9 +1275,9 @@ void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension)
}

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(a), 3, "dimension %d out of range",
dimension+1);
dimension + TH_INDEX_BASE);
THArgCheck(THTensor_(size)(a, dimension) == 3, 3, "dimension %d does not have size 3",
dimension+1);
dimension + TH_INDEX_BASE);

THTensor_(resizeAs)(r_, a);

Expand Down Expand Up @@ -1650,7 +1652,7 @@ static void THTensor_(quicksortdescend)(real *arr, long *idx, long elements, lon
void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder)
{
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d",
dimension+1);
dimension + TH_INDEX_BASE);

THTensor_(resizeAs)(rt_, t);
THTensor_(copy)(rt_, t);
Expand Down Expand Up @@ -1993,7 +1995,7 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
}

THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension+1);
THArgCheck(dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);

size = THLongStorage_newWithSize(ndim);
for(i = 0; i < ndim; i++)
Expand Down Expand Up @@ -2182,7 +2184,7 @@ void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension)
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand All @@ -2202,7 +2204,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag)
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand Down Expand Up @@ -2243,7 +2245,7 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int flag)
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand Down Expand Up @@ -2284,7 +2286,7 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension)
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
dimension+1);
dimension + TH_INDEX_BASE);

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand Down Expand Up @@ -2332,7 +2334,7 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, real value, int dimension,
THTensor *rowR, *rowS;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(src), 3, "invalid dimension %d",
dimension+1);
dimension + TH_INDEX_BASE);
THArgCheck(value > 0, 2, "non-positive-norm not supported");
THArgCheck(THTensor_(nDimension)(src) > 1, 1, "need at least 2 dimensions, got %d dimensions",
THTensor_(nDimension)(src));
Expand Down

0 comments on commit be269db

Please sign in to comment.