diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 6319fedbe7..668a4da610 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -192,7 +192,7 @@ represent the corresponding Burn Op. | [TfIdfVectorizer][183] | ❌ | ❌ | | [ThresholdedRelu][184] | ❌ | ❌ | | [Tile][185] | ✅ | ✅ | -| [TopK][186] | ❌ | ✅ | +| [TopK][186] | ✅ | ✅ | | [Transpose][187] | ✅ | ✅ | | [Trilu][188] | ❌ | ✅ | | [Unique][189] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index a7360e012d..ff4a034cbe 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -107,6 +107,7 @@ fn main() { .input("tests/sum/sum_int.onnx") .input("tests/tanh/tanh.onnx") .input("tests/tile/tile.onnx") + .input("tests/top_k/top_k.onnx") .input("tests/transpose/transpose.onnx") .input("tests/unsqueeze/unsqueeze.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index a4cd32b485..cfbce8830d 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -116,6 +116,7 @@ include_models!( sum_int, tanh, tile, + top_k_opset_1, transpose, unsqueeze, unsqueeze_opset11, @@ -128,7 +129,7 @@ mod tests { use super::*; - use burn::tensor::{Bool, Int, Shape, Tensor, TensorData}; + use burn::tensor::{cast::ToElement, Bool, Int, Shape, Tensor, TensorData}; use float_cmp::ApproxEq; @@ -2125,4 +2126,30 @@ mod tests { assert!(i_output.equal(i_expected).all().into_scalar()); assert!(b_output.equal(b_expected).all().into_scalar()); } + + #[test] + fn top_k_opset_1() { + // Initialize the model + let device = Default::default(); + let model = top_k_opset1::Model::::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], + &device, + ); + let (values_tensor, indices_tensor) = model.forward(input); + + // expected results + let expected_values_tensor = + TensorData::from([[4.0, 3.0, 2.to_f32()], [4.0, 3.0, 2.to_f32()]]); + let expected_indices_tensor = TensorData::from([[3, 2, 1], [3, 2, 1]]); + + values_tensor + .to_data() + .assert_eq(&expected_values_tensor, true); + indices_tensor + .to_data() + .assert_eq(&expected_indices_tensor, true); + } } diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx b/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx new file mode 100644 index 0000000000..4eb08a05c9 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/top_k/top_k.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k.py b/crates/burn-import/onnx-tests/tests/top_k/top_k.py new file mode 100644 index 0000000000..176c743875 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/top_k/top_k.py @@ -0,0 +1,75 @@ +import numpy as np +import onnx +from onnx import helper, TensorProto + +# Define the input tensor +X = np.array([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], dtype=np.float32) + +# Define the value of K +k = 3 +K = np.array([k], dtype=np.int64) +axis = 1 +new_dims = [X.shape[0], k] + +def create_model(op_set_version: int): + input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)] + + output_tensors = [ + helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims), + helper.make_tensor_value_info('Indices', TensorProto.INT32, new_dims) + ] + + # Create the TopK node + if op_set_version > 1: + node = helper.make_node( + 'TopK', + inputs=['X', 'K'], + outputs=['Values', 'Indices'], + axis=axis, # Axis along which to find the top K elements + ) + input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT32, K.shape)) + else: + node = helper.make_node( + 'TopK', + inputs=['X'], + outputs=['Values', 'Indices'], + axis=axis, # Axis along which to find the top K elements + k=k + ) + + # Create the graph + graph = helper.make_graph( + nodes = [node], + name = 'TopKGraph', + inputs = input_tensors, + outputs = output_tensors, + # Unconmment when initializers are supported. Currently we can't test opset 10/11 since the code will require a k value to be initialized for testing. + #initializer = [ + # helper.make_tensor('X', TensorProto.FLOAT, X.shape, X), + # helper.make_tensor('K', TensorProto.INT64, [1], [k]), + #] + ) + + # Create the model + model = helper.make_model( + graph, + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)] + ) + # Check the model + onnx.checker.check_model(model) + + # Save the model to a file + onnx.save(model, f'top_k_opset_{op_set_version}.onnx') + print(f"Model saved to top_k_opset_{op_set_version}.onnx") + +def main(): + # Unconmment when initializers are supported. + for op_set_version in [1, 10, 11]: + create_model(op_set_version) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx b/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx new file mode 100644 index 0000000000..4eb08a05c9 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/top_k/top_k_opset_1.onnx differ diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index a1c9103b41..69fc8f9712 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -11,7 +11,8 @@ use super::{ max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, - squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, top_k::TopKNode, unary::UnaryNode, + unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -114,6 +115,7 @@ pub enum Node { Squeeze(SqueezeNode), Sum(SumNode), Tile(TileNode), + TopK(TopKNode), Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), @@ -162,6 +164,7 @@ macro_rules! match_all { Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), Node::Tile(node) => $func(node), + Node::TopK(node) => $func(node), Node::Unary(node) => $func(node), Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), @@ -218,6 +221,7 @@ impl Node { Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", Node::Tile(_) => "tile", + Node::TopK(_) => "top_k", Node::Unary(unary) => unary.kind.as_str(), Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index ee294ddfd7..605cd9d8f2 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod slice; pub(crate) mod squeeze; pub(crate) mod sum; pub(crate) mod tile; +pub(crate) mod top_k; pub(crate) mod unary; pub(crate) mod unsqueeze; pub(crate) use base::*; diff --git a/crates/burn-import/src/burn/node/top_k.rs b/crates/burn-import/src/burn/node/top_k.rs new file mode 100644 index 0000000000..4fc899dcf7 --- /dev/null +++ b/crates/burn-import/src/burn/node/top_k.rs @@ -0,0 +1,114 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; + +#[derive(Config, Debug)] +pub struct TopKConfig { + pub axis: usize, + pub k: usize, + pub largest: usize, +} + +#[derive(Debug, Clone, new)] +pub struct TopKNode { + pub input: TensorType, + pub outputs: Vec, + pub config: TopKConfig, +} + +impl NodeCodegen for TopKNode { + fn output_types(&self) -> Vec { + self.outputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let axis = self.config.axis.to_token_stream(); + let k = self.config.k.to_token_stream(); + let largest = self.config.largest.to_token_stream(); + + let input = scope.tensor_use_owned(&self.input, node_position); + let values_output = &self.outputs[0].name; + let indices_output = &self.outputs[1].name; + + quote! { + let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis, #largest); + } + } + + fn into_node(self) -> Node { + Node::TopK(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{test::assert_tokens, top_k::TopKNode}, + TensorType, + }; + + #[test] + fn test_codegen_nodes() { + let mut graph = BurnGraph::::default(); + let config = TopKConfig::new(1, 3, 1); + + graph.register(TopKNode::new( + TensorType::new_float("input_tensor", 4), + vec![ + TensorType::new_float("values_tensor", 4), + TensorType::new_int("indices_tensor", 4), + ], + config, + )); + + graph.register_input_output( + vec!["input_tensor".to_string()], + vec!["values_tensor".to_string(), "indices_tensor".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input_tensor: Tensor) -> (Tensor, Tensor) { + let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize, 1usize); + (values_tensor, indices_tensor) + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 4621b129d6..bda0d745bf 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,9 +7,16 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig}; +use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig}; use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node}; +/// Extract and convert a given attribute to i64 +fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { + let error_msg = format!("Expected the following attribute key: {:?}", key); + let value = node.attrs.get(key).expect(&error_msg).clone().into_i64(); + value +} + /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec @@ -795,6 +802,38 @@ pub fn tile_config(node: &Node) -> TileConfig { TileConfig::new(repeat) } +/// Create a TopKConfig from the attributes of the node. +pub fn top_k_config(node: &Node) -> TopKConfig { + // extract the shape of the input data tensor + let data_tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + let k = match node.inputs.get(1) { + Some(k_tensor) => k_tensor + .clone() + .value + .expect("Expecting K tensor to have a value.") + .into_i64s()[0], + _ => extract_attr_value_i64(node, "k"), + }; + + let mut axis: i64 = extract_attr_value_i64(node, "axis"); + + // if axis is negative, it is counted from the end + if axis < 0 { + axis += data_tensor.dim as i64; + } + + let largest = match node.attrs.get("largest") { + Some(val) => val.clone().into_i64(), + _ => 1, + }; + + TopKConfig::new(axis as usize, k as usize, largest as usize) +} + /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads_input(node: &Node) -> Vec { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 5c4f34078f..70ef970337 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -51,6 +51,7 @@ use crate::{ squeeze::SqueezeNode, sum::SumNode, tile::TileNode, + top_k::TopKNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }, @@ -67,8 +68,8 @@ use super::op_configuration::{ hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, - shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config, - unsqueeze_config, + shape_config, slice_config, softmax_config, squeeze_config, tile_config, top_k_config, + transpose_config, unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -338,6 +339,7 @@ impl ParsedOnnxGraph { NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)), NodeType::Tile => graph.register(Self::tile_conversion(node)), + NodeType::TopK => graph.register(Self::top_k_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), NodeType::ConstantOfShape => { graph.register(Self::constant_of_shape_conversion(node)) @@ -357,6 +359,7 @@ impl ParsedOnnxGraph { .iter() .map(|input| input.name.clone()) .collect::>(); + let output_names = self .0 .outputs @@ -1184,6 +1187,17 @@ impl ParsedOnnxGraph { TileNode::new(input, output, config) } + + fn top_k_conversion(node: Node) -> TopKNode { + // Inputs + let input = TensorType::from(node.inputs.first().unwrap()); + + // Outputs + let outputs = node.outputs.iter().map(TensorType::from).collect(); + let config = top_k_config(&node); + + TopKNode::new(input, outputs, config) + } } /// Extract data from node states and convert it to `TensorData`. diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index b67293715b..77057e2f6d 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -726,20 +726,53 @@ where } /// Returns the `k` largest elements of the given input tensor along a given dimension. - pub fn topk(self, k: usize, dim: usize) -> Tensor { + pub fn topk(self, k: usize, dim: usize, largest: Option) -> Tensor { let k_indices = Tensor::arange(0..k as i64, &self.device()); - self.sort_descending(dim).select(dim, k_indices) + match largest { + Some(largest) => { + if largest == 1 { + self.sort_descending(dim).select(dim, k_indices) + } else { + self.sort(dim).select(dim, k_indices) + } + } + _ => self.sort_descending(dim).select(dim, k_indices), + } } /// Returns the `k` largest elements of the given input tensor along a given dimension. /// Also returns the indices. - pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor, Tensor) { + pub fn topk_with_indices( + self, + k: usize, + dim: usize, + largest: Option, + ) -> (Tensor, Tensor) { let k_indices = Tensor::arange(0..k as i64, &self.device()); - let (values, indices) = self.sort_descending_with_indices(dim); - ( - values.select(dim, k_indices.clone()), - indices.select(dim, k_indices), - ) + match largest { + Some(largest) => { + if largest == 1 { + let (values, indices) = self.sort_descending_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } else { + let (values, indices) = self.sort_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } + } + _ => { + let (values, indices) = self.sort_descending_with_indices(dim); + ( + values.select(dim, k_indices.clone()), + indices.select(dim, k_indices), + ) + } + } } /// Pad the tensor of rank two or higher with the given value on the last two dimensions. diff --git a/crates/burn-tensor/src/tests/ops/topk.rs b/crates/burn-tensor/src/tests/ops/topk.rs index 9d98926655..59411f7bd8 100644 --- a/crates/burn-tensor/src/tests/ops/topk.rs +++ b/crates/burn-tensor/src/tests/ops/topk.rs @@ -6,48 +6,85 @@ mod tests { #[test] fn test_topk_1d() { // Int + // largest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let values = tensor.topk(3, /*dim*/ 0); + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&expected, false); + // smallest + let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); + + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1, 2, 3]); + + values.into_data().assert_eq(&expected, false); + // Float + // largest let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); - let values = tensor.topk(3, /*dim*/ 0); + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5., 4., 3.]); values.into_data().assert_approx_eq(&expected, 5); + // Float + // smallest + let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]); + + let values = tensor.topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1., 2., 3.]); + + values.into_data().assert_approx_eq(&expected, 1); } #[test] fn test_topk() { // 3D Int + // largest let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); - let values = tensor.topk(2, /*dim*/ 2); + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]); values.into_data().assert_eq(&expected, false); + // smallest + let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]); + + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]); + + values.into_data().assert_eq(&expected, false); + // 3D Float + // largest let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); - let values = tensor.topk(2, /*dim*/ 2); + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]); values.into_data().assert_approx_eq(&expected, 5); + + // smallest + let tensor = + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]); + + let values = tensor.topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1, 4], [2, 5]], [[0, 3], [2, 8]]]); } #[test] fn test_topk_with_indices() { // 1D + // largest let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); + let (values, indices) = + tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5, 4, 3]); values.into_data().assert_eq(&values_expected, false); @@ -55,11 +92,25 @@ mod tests { let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); + // smallest + let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]); + + let (values, indices) = + tensor.topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + + let values_expected = TensorData::from([1, 2, 3]); + values.into_data().assert_eq(&values_expected, false); + + let indices_expected = TensorData::from([0, 1, 2]); + indices.into_data().assert_eq(&indices_expected, false); + // 3D + // largest let tensor = TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); + let (values, indices) = + tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -68,5 +119,20 @@ mod tests { let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); + + // smallest + let tensor = + TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]); + + let (values, indices) = + tensor.topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + + let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + + values.into_data().assert_approx_eq(&values_expected, 5); + + let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); + + indices.into_data().assert_eq(&indices_expected, false); } } diff --git a/crates/burn-tensor/src/tests/quantization/ops/topk.rs b/crates/burn-tensor/src/tests/quantization/ops/topk.rs index 7913fb79e1..78b4900523 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/topk.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/topk.rs @@ -14,9 +14,19 @@ mod tests { ); let tensor = TestTensor::<1>::from_data(data, &Default::default()); - let values = tensor.topk(3, /*dim*/ 0); + // largest + let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(1)); let expected = TensorData::from([5., 4., 3.]); + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 3); + + // smallest + let values = tensor.clone().topk(3, /*dim*/ 0, /*largest*/ Some(0)); + let expected = TensorData::from([1., 2., 3.]); + values .dequantize() .into_data() @@ -24,7 +34,7 @@ mod tests { } #[test] - fn test_topk() { + fn test_topk_3d() { // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] let data = TensorData::quantized( vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], @@ -33,9 +43,20 @@ mod tests { ); let tensor = TestTensor::<3>::from_data(data, &Default::default()); - let values = tensor.topk(2, /*dim*/ 2); + // largest + let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(1)); let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&expected, 1); + + // smallest + let values = tensor.clone().topk(2, /*dim*/ 2, /*largest*/ Some(0)); + let expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + // Precision 1 to approximate de/quantization errors values .dequantize() @@ -44,8 +65,7 @@ mod tests { } #[test] - fn test_topk_with_indices() { - // 1D + fn test_topk_with_indices_1d() { // Quantized [1.0, 2.0, 3.0, 4.0, 5.0] let data = TensorData::quantized( vec![-77i8, -26, 25, 76, 127], @@ -54,7 +74,11 @@ mod tests { ); let tensor = TestTensor::<1>::from_data(data, &Default::default()); - let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0); + // largest + let (values, indices) = + tensor + .clone() + .topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(1)); let values_expected = TensorData::from([5., 4., 3.]); values @@ -65,7 +89,24 @@ mod tests { let indices_expected = TensorData::from([4, 3, 2]); indices.into_data().assert_eq(&indices_expected, false); - // 3D + // smallest + let (values, indices) = + tensor + .clone() + .topk_with_indices(3, /*dim*/ 0, /*largest*/ Some(0)); + + let values_expected = TensorData::from([1., 2., 3.]); + values + .dequantize() + .into_data() + .assert_eq(&values_expected, false); + + let indices_expected = TensorData::from([0, 1, 2]); + indices.into_data().assert_eq(&indices_expected, false); + } + + #[test] + fn test_topk_with_indices_3d() { // Quantized [[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]] let data = TensorData::quantized( vec![-100i8, -15, 70, -71, 14, 42, -43, -128, 127, 99, -71, 70], @@ -74,7 +115,11 @@ mod tests { ); let tensor = TestTensor::<3>::from_data(data, &Default::default()); - let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2); + // largest + let (values, indices) = + tensor + .clone() + .topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(1)); let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]); @@ -87,5 +132,23 @@ mod tests { let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]); indices.into_data().assert_eq(&indices_expected, false); + + // smallest + let (values, indices) = + tensor + .clone() + .topk_with_indices(2, /*dim*/ 2, /*largest*/ Some(0)); + + let values_expected = TensorData::from([[[1., 4.], [2., 5.]], [[0., 3.], [2., 7.]]]); + + // Precision 1 to approximate de/quantization errors + values + .dequantize() + .into_data() + .assert_approx_eq(&values_expected, 1); + + let indices_expected = TensorData::from([[[0, 1], [0, 1]], [[1, 0], [1, 2]]]); + + indices.into_data().assert_eq(&indices_expected, false); } } diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ff580b37aa..f093421abc 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -82,6 +82,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Sub => same_as_input_broadcast(node), NodeType::Sum => same_as_input_broadcast(node), NodeType::Tanh => same_as_input(node), + NodeType::TopK => top_k_update_output(node), NodeType::Transpose => same_as_input(node), NodeType::Unsqueeze => unsqueeze_update_output(node), NodeType::Where => where_update_outputs(node), @@ -477,6 +478,35 @@ fn same_as_input(node: &mut Node) { node.outputs[0].ty = node.inputs[0].ty.clone(); } +fn top_k_update_output(node: &mut Node) { + let dim = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.dim, + _ => panic!("TopK: invalid input type"), + }; + + let output_values_elem = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.elem_type.clone(), + _ => panic!("TopK: invalid output type"), + }; + + let output_indices_elem = match &node.outputs[1].ty { + ArgType::Tensor(_) => ElementType::Int64, + _ => panic!("TopK: invalid output type"), + }; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim, + shape: None, // shape is tracked and calculated at runtime + elem_type: output_values_elem, + }); + + node.outputs[1].ty = ArgType::Tensor(TensorType { + dim, + shape: None, // shape is tracked and calculated at runtime + elem_type: output_indices_elem, + }); +} + /// Temporary pass-through stub for dimension inference so that we can export the IR model. fn temporary_pass_through_stub(node: &mut Node) { log::warn!("Must implement dimension inference for {:?}", node); diff --git a/crates/onnx-ir/src/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs index fa30bcf83c..470ab2ddca 100644 --- a/crates/onnx-ir/src/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -151,7 +151,7 @@ impl GraphData { for output in node.outputs.iter_mut() { self.input_name_map.insert( output.name.clone(), - IOEntry::Node(self.processed_nodes.len(), 0), + IOEntry::Node(self.processed_nodes.len(), out_count - 1), ); output.name = format!("{}_out{}", node.name, out_count); out_count += 1; @@ -239,7 +239,6 @@ impl OnnxGraphBuilder { // TODO Update graph inputs and outputs to match the processed nodes inputs and outputs // This is necessary for the graph to be valid // ConstantOfShape updates input to be Shape argument and output Tensor dim is updated - OnnxGraph { nodes: processed_nodes, inputs,