diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 147080a31e..cfbce8830d 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -2141,10 +2141,15 @@ mod tests { let (values_tensor, indices_tensor) = model.forward(input); // expected results - let expected_values_tensor = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); + let expected_values_tensor = + TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); let expected_indices_tensor = TensorData::from([[3, 2, 1], [3, 2, 1]]); - values_tensor.to_data().assert_eq(&expected_values_tensor, true); - indices_tensor.to_data().assert_eq(&expected_indices_tensor, true); + values_tensor + .to_data() + .assert_eq(&expected_values_tensor, true); + indices_tensor + .to_data() + .assert_eq(&expected_indices_tensor, true); } } diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 1290effdd7..bda0d745bf 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -811,10 +811,12 @@ pub fn top_k_config(node: &Node) -> TopKConfig { }; let k = match node.inputs.get(1) { - Some(k_tensor) => { - k_tensor.clone().value.expect("Expecting K tensor to have a value.").into_i64s()[0] - } - _ => extract_attr_value_i64(node, "k") + Some(k_tensor) => k_tensor + .clone() + .value + .expect("Expecting K tensor to have a value.") + .into_i64s()[0], + _ => extract_attr_value_i64(node, "k"), }; let mut axis: i64 = extract_attr_value_i64(node, "axis"); @@ -826,7 +828,7 @@ pub fn top_k_config(node: &Node) -> TopKConfig { let largest = match node.attrs.get("largest") { Some(val) => val.clone().into_i64(), - _ => 1 + _ => 1, }; TopKConfig::new(axis as usize, k as usize, largest as usize) diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 785caf1c27..77057e2f6d 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -735,33 +735,36 @@ where } else { self.sort(dim).select(dim, k_indices) } - }, - _ => { - self.sort_descending(dim).select(dim, k_indices) } + _ => self.sort_descending(dim).select(dim, k_indices), } } /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. - pub fn topk_with_indices(self, k: usize, dim: usize, largest: Option) -> (Tensor, Tensor) { + pub fn topk_with_indices( + self, + k: usize, + dim: usize, + largest: Option, + ) -> (Tensor, Tensor) { let k_indices = Tensor::arange(0..k as i64, &self.device()); match largest { Some(largest) => { if largest == 1 { - let (values, indices) = self.sort_with_indices(dim); + let (values, indices) = self.sort_descending_with_indices(dim); ( values.select(dim, k_indices.clone()), indices.select(dim, k_indices), ) } else { - let (values, indices) = self.sort_descending_with_indices(dim); + let (values, indices) = self.sort_with_indices(dim); ( values.select(dim, k_indices.clone()), indices.select(dim, k_indices), ) } - }, + } _ => { let (values, indices) = self.sort_descending_with_indices(dim); ( diff --git a/crates/burn-tensor/src/tests/ops/topk.rs b/crates/burn-tensor/src/tests/ops/topk.rs index 6d8f3a0784..59411f7bd8 100644 --- a/crates/burn-tensor/src/tests/ops/topk.rs +++ b/crates/burn-tensor/src/tests/ops/topk.rs @@ -30,7 +30,6 @@ mod tests { let expected = TensorData::from([5., 4., 3.]); values.into_data().assert_approx_eq(&expected, 5); - // Float // smallest let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); @@ -72,7 +71,7 @@ mod tests { // smallest let tensor = - TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0)); let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]); @@ -84,7 +83,8 @@ mod tests { // largest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); + let (values, indices) = + tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&values_expected, false); @@ -95,7 +95,8 @@ mod tests { // smallest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + let (values, indices) = + tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); let values_expected = TensorData::from([1, 2, 3]); values.into_data().assert_eq(&values_expected, false); @@ -108,7 +109,8 @@ mod tests { let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); + let (values, indices) = + tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -120,9 +122,10 @@ mod tests { // smallest let tensor = - TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + let (values, indices) = + tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); @@ -131,6 +134,5 @@ mod tests { let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); indices.into_data().assert_eq(&indices_expected, false); - } } diff --git a/crates/burn-tensor/src/tests/quantization/ops/topk.rs b/crates/burn-tensor/src/tests/quantization/ops/topk.rs index 7913fb79e1..78b4900523 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/topk.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/topk.rs @@ -14,9 +14,19 @@ mod tests { ); let tensor = TestTensor::<1>::from_data(data, &Default::default()); - let values = tensor.topk(3, /*dim*/ 0); + // largest + let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5., 4., 3.]); + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + + // smallest + let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1., 2., 3.]); + values .dequantize() .into_data() @@ -24,7 +34,7 @@ mod tests { } #[test] - fn test_topk() { + fn test_topk_3d() { // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] let data = TensorData::quantized( vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], @@ -33,9 +43,20 @@ mod tests { ); let tensor = TestTensor::<3>::from_data(data, &Default::default()); - let values = tensor.topk(2, /*dim*/ 2); + // largest + let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + + // smallest + let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + // Precision 1 to approximate de/quantization errors values .dequantize() @@ -44,8 +65,7 @@ mod tests { } #[test] - fn test_topk_with_indices() { - // 1D + fn test_topk_with_indices_1d() { // Quantized [1.0, 2.0, 3.0, 4.0, 5.0] let data = TensorData::quantized( vec![-77i8, -26, 25, 76, 127], @@ -54,7 +74,11 @@ mod tests { ); let tensor = TestTensor::<1>::from_data(data, &Default::default()); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); + // largest + let (values, indices) = + tensor + .clone() + .topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5., 4., 3.]); values @@ -65,7 +89,24 @@ mod tests { let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); - // 3D + // smallest + let (values, indices) = + tensor + .clone() + .topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + + let values_expected = TensorData::from([1., 2., 3.]); + values + .dequantize() + .into_data() + .assert_eq(&values_expected, false); + + let indices_expected = TensorData::from([0, 1, 2]); + indices.into_data().assert_eq(&indices_expected, false); + } + + #[test] + fn test_topk_with_indices_3d() { // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] let data = TensorData::quantized( vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], @@ -74,7 +115,11 @@ mod tests { ); let tensor = TestTensor::<3>::from_data(data, &Default::default()); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); + // largest + let (values, indices) = + tensor + .clone() + .topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -87,5 +132,23 @@ mod tests { let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); + + // smallest + let (values, indices) = + tensor + .clone() + .topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + + let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); + + indices.into_data().assert_eq(&indices_expected, false); } }