Skip to content

Commit

Permalink
run checks and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oojo12 committed Sep 28, 2024
1 parent 07f79ab commit 7e221c9
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 31 deletions.
11 changes: 8 additions & 3 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2141,10 +2141,15 @@ mod tests {
let (values_tensor, indices_tensor) = model.forward(input);

// expected results
let expected_values_tensor = TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]);
let expected_values_tensor =
TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]);
let expected_indices_tensor = TensorData::from([[3, 2, 1], [3, 2, 1]]);

values_tensor.to_data().assert_eq(&expected_values_tensor, true);
indices_tensor.to_data().assert_eq(&expected_indices_tensor, true);
values_tensor
.to_data()
.assert_eq(&expected_values_tensor, true);
indices_tensor
.to_data()
.assert_eq(&expected_indices_tensor, true);
}
}
12 changes: 7 additions & 5 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,10 +811,12 @@ pub fn top_k_config(node: &Node) -> TopKConfig {
};

let k = match node.inputs.get(1) {
Some(k_tensor) => {
k_tensor.clone().value.expect("Expecting K tensor to have a value.").into_i64s()[0]
}
_ => extract_attr_value_i64(node, "k")
Some(k_tensor) => k_tensor
.clone()
.value
.expect("Expecting K tensor to have a value.")
.into_i64s()[0],
_ => extract_attr_value_i64(node, "k"),
};

let mut axis: i64 = extract_attr_value_i64(node, "axis");
Expand All @@ -826,7 +828,7 @@ pub fn top_k_config(node: &Node) -> TopKConfig {

let largest = match node.attrs.get("largest") {
Some(val) => val.clone().into_i64(),
_ => 1
_ => 1,
};

TopKConfig::new(axis as usize, k as usize, largest as usize)
Expand Down
17 changes: 10 additions & 7 deletions crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,33 +735,36 @@ where
} else {
self.sort(dim).select(dim, k_indices)
}
},
_ => {
self.sort_descending(dim).select(dim, k_indices)
}
_ => self.sort_descending(dim).select(dim, k_indices),
}
}

/// Returns the `k` largest elements of the given input tensor along a given dimension.
/// Also returns the indices.
pub fn topk_with_indices(self, k: usize, dim: usize, largest: Option<usize>) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
pub fn topk_with_indices(
self,
k: usize,
dim: usize,
largest: Option<usize>,
) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
let k_indices = Tensor::arange(0..k as i64, &self.device());
match largest {
Some(largest) => {
if largest == 1 {
let (values, indices) = self.sort_with_indices(dim);
let (values, indices) = self.sort_descending_with_indices(dim);
(
values.select(dim, k_indices.clone()),
indices.select(dim, k_indices),
)
} else {
let (values, indices) = self.sort_descending_with_indices(dim);
let (values, indices) = self.sort_with_indices(dim);
(
values.select(dim, k_indices.clone()),
indices.select(dim, k_indices),
)
}
},
}
_ => {
let (values, indices) = self.sort_descending_with_indices(dim);
(
Expand Down
18 changes: 10 additions & 8 deletions crates/burn-tensor/src/tests/ops/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ mod tests {
let expected = TensorData::from([5., 4., 3.]);

values.into_data().assert_approx_eq(&expected, 5);

// Float
// smallest
let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]);
Expand Down Expand Up @@ -72,7 +71,7 @@ mod tests {

// smallest
let tensor =
TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]);
TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]);

let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0));
let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]);
Expand All @@ -84,7 +83,8 @@ mod tests {
// largest
let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);

let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1));
let (values, indices) =
tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1));

let values_expected = TensorData::from([5, 4, 3]);
values.into_data().assert_eq(&values_expected, false);
Expand All @@ -95,7 +95,8 @@ mod tests {
// smallest
let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);

let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0));
let (values, indices) =
tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0));

let values_expected = TensorData::from([1, 2, 3]);
values.into_data().assert_eq(&values_expected, false);
Expand All @@ -108,7 +109,8 @@ mod tests {
let tensor =
TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);

let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1));
let (values, indices) =
tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1));

let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);

Expand All @@ -120,9 +122,10 @@ mod tests {

// smallest
let tensor =
TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);
TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);

let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0));
let (values, indices) =
tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0));

let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]);

Expand All @@ -131,6 +134,5 @@ mod tests {
let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]);

indices.into_data().assert_eq(&indices_expected, false);

}
}
79 changes: 71 additions & 8 deletions crates/burn-tensor/src/tests/quantization/ops/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,27 @@ mod tests {
);
let tensor = TestTensor::<1>::from_data(data, &Default::default());

let values = tensor.topk(3, /*dim*/ 0);
// largest
let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(1));
let expected = TensorData::from([5., 4., 3.]);

values
.dequantize()
.into_data()
.assert_approx_eq(&expected, 3);

// smallest
let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(0));
let expected = TensorData::from([1., 2., 3.]);

values
.dequantize()
.into_data()
.assert_approx_eq(&expected, 3);
}

#[test]
fn test_topk() {
fn test_topk_3d() {
// Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]
let data = TensorData::quantized(
vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70],
Expand All @@ -33,9 +43,20 @@ mod tests {
);
let tensor = TestTensor::<3>::from_data(data, &Default::default());

let values = tensor.topk(2, /*dim*/ 2);
// largest
let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(1));
let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);

// Precision 1 to approximate de/quantization errors
values
.dequantize()
.into_data()
.assert_approx_eq(&expected, 1);

// smallest
let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(0));
let expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]);

// Precision 1 to approximate de/quantization errors
values
.dequantize()
Expand All @@ -44,8 +65,7 @@ mod tests {
}

#[test]
fn test_topk_with_indices() {
// 1D
fn test_topk_with_indices_1d() {
// Quantized [1.0, 2.0, 3.0, 4.0, 5.0]
let data = TensorData::quantized(
vec![-77i8, -26, 25, 76, 127],
Expand All @@ -54,7 +74,11 @@ mod tests {
);
let tensor = TestTensor::<1>::from_data(data, &Default::default());

let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0);
// largest
let (values, indices) =
tensor
.clone()
.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1));

let values_expected = TensorData::from([5., 4., 3.]);
values
Expand All @@ -65,7 +89,24 @@ mod tests {
let indices_expected = TensorData::from([4, 3, 2]);
indices.into_data().assert_eq(&indices_expected, false);

// 3D
// smallest
let (values, indices) =
tensor
.clone()
.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0));

let values_expected = TensorData::from([1., 2., 3.]);
values
.dequantize()
.into_data()
.assert_eq(&values_expected, false);

let indices_expected = TensorData::from([0, 1, 2]);
indices.into_data().assert_eq(&indices_expected, false);
}

#[test]
fn test_topk_with_indices_3d() {
// Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]
let data = TensorData::quantized(
vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70],
Expand All @@ -74,7 +115,11 @@ mod tests {
);
let tensor = TestTensor::<3>::from_data(data, &Default::default());

let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2);
// largest
let (values, indices) =
tensor
.clone()
.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1));

let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);

Expand All @@ -87,5 +132,23 @@ mod tests {
let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]);

indices.into_data().assert_eq(&indices_expected, false);

// smallest
let (values, indices) =
tensor
.clone()
.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0));

let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]);

// Precision 1 to approximate de/quantization errors
values
.dequantize()
.into_data()
.assert_approx_eq(&values_expected, 1);

let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]);

indices.into_data().assert_eq(&indices_expected, false);
}
}

0 comments on commit 7e221c9

Please sign in to comment.