Skip to content

Commit

Permalink
Fix accepted k range in THTensor_kthvalue
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke committed Jul 20, 2016
1 parent be269db commit d849bc4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ static const void* torch_istensorarray(lua_State *L, int idx)
lua_checkstack(L, 3);
lua_rawgeti(L, idx, 1);
tensor_idx = lua_gettop(L);
tensor_idx = lua_gettop(L);
tname = (torch_istensortype(L, luaT_typename(L, -1)));
lua_remove(L, tensor_idx);
return tname;
Expand Down Expand Up @@ -316,7 +316,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
{{name=Tensor, default=true, returned=true, method={default='nil'}},
{name=Tensor, method={default=1}},
{name=real}})

-- mod alias
wrap("mod",
cname("fmod"),
Expand Down Expand Up @@ -652,7 +652,7 @@ wrap("topk",
{{name=Tensor, default=true, returned=true},
{name="IndexTensor", default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="index"},
{name="long"},
{name="index", default=lastdim(3)}})

wrap("mode",
Expand Down
10 changes: 5 additions & 5 deletions lib/TH/generic/THTensorMath.c
Original file line number Diff line number Diff line change
Expand Up @@ -1804,7 +1804,7 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
long t_size_dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");
THArgCheck(k >= 0 && k < t->size[dimension], 2, "selected index out of range");
THArgCheck(k > 0 && k <= t->size[dimension], 2, "selected index out of range");

dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
Expand All @@ -1828,9 +1828,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
temp__data[i] = t_data[i*t_stride];
for(i = 0; i < t_size_dim; i++)
tempi__data[i] = i;
THTensor_(quickselect)(temp__data, tempi__data, k, t_size_dim, 1);
*values__data = temp__data[k];
*indices__data = tempi__data[k];);
THTensor_(quickselect)(temp__data, tempi__data, k - 1, t_size_dim, 1);
*values__data = temp__data[k-1];
*indices__data = tempi__data[k-1];);

THTensor_(free)(temp_);
THLongTensor_free(tempi_);
Expand All @@ -1845,7 +1845,7 @@ void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, i
t_size_dim = THTensor_(size)(t, dimension);
k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */

THTensor_(kthvalue)(values_, indices_, t, k, dimension);
THTensor_(kthvalue)(values_, indices_, t, k+1, dimension);
}

void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, long k, int dim, int dir, int sorted)
Expand Down

0 comments on commit d849bc4

Please sign in to comment.