From d849bc44e36ff6fe3826d44c31e99058d2f27c76 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 20 Jul 2016 14:10:06 -0400 Subject: [PATCH] Fix accepted k range in THTensor_kthvalue --- TensorMath.lua | 6 +++--- lib/TH/generic/THTensorMath.c | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/TensorMath.lua b/TensorMath.lua index 5a37b122..6b79237c 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -61,7 +61,7 @@ static const void* torch_istensorarray(lua_State *L, int idx) lua_checkstack(L, 3); lua_rawgeti(L, idx, 1); - tensor_idx = lua_gettop(L); + tensor_idx = lua_gettop(L); tname = (torch_istensortype(L, luaT_typename(L, -1))); lua_remove(L, tensor_idx); return tname; @@ -316,7 +316,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor", {{name=Tensor, default=true, returned=true, method={default='nil'}}, {name=Tensor, method={default=1}}, {name=real}}) - + -- mod alias wrap("mod", cname("fmod"), @@ -652,7 +652,7 @@ wrap("topk", {{name=Tensor, default=true, returned=true}, {name="IndexTensor", default=true, returned=true, noreadadd=true}, {name=Tensor}, - {name="index"}, + {name="long"}, {name="index", default=lastdim(3)}}) wrap("mode", diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 4338e3d7..47bd6ec9 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -1804,7 +1804,7 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, long t_size_dim; THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range"); - THArgCheck(k >= 0 && k < t->size[dimension], 2, "selected index out of range"); + THArgCheck(k > 0 && k <= t->size[dimension], 2, "selected index out of range"); dim = THTensor_(newSizeOf)(t); THLongStorage_set(dim, dimension, 1); @@ -1828,9 +1828,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, temp__data[i] = t_data[i*t_stride]; for(i = 0; i < t_size_dim; i++) tempi__data[i] = i; - THTensor_(quickselect)(temp__data, tempi__data, k, t_size_dim, 1); - *values__data = temp__data[k]; - *indices__data = tempi__data[k];); + THTensor_(quickselect)(temp__data, tempi__data, k - 1, t_size_dim, 1); + *values__data = temp__data[k-1]; + *indices__data = tempi__data[k-1];); THTensor_(free)(temp_); THLongTensor_free(tempi_); @@ -1845,7 +1845,7 @@ void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, i t_size_dim = THTensor_(size)(t, dimension); k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */ - THTensor_(kthvalue)(values_, indices_, t, k, dimension); + THTensor_(kthvalue)(values_, indices_, t, k+1, dimension); } void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, long k, int dim, int dir, int sorted)