-
Notifications
You must be signed in to change notification settings - Fork 476
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Onnx op topk #2305
base: main
Are you sure you want to change the base?
Onnx op topk #2305
Changes from all commits
da34ba5
2973945
578c346
17dce60
ce1aa89
5783f76
16ce366
4ca56c6
49a6a73
f293ad3
449f519
249d60c
07f79ab
7e221c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import numpy as np | ||
import onnx | ||
from onnx import helper, TensorProto | ||
|
||
# Input tensor: a 3x4 float32 matrix holding the values 0..11 in row-major order.
X = np.arange(12, dtype=np.float32).reshape(3, 4)

# Number of top elements to select, as a plain int and as a one-element int64 array.
k = 3
K = np.asarray([k], dtype=np.int64)

# TopK runs along axis 1, so each output keeps the row count of X and has k columns.
axis = 1
new_dims = [len(X), k]
|
||
def create_model(op_set_version: int):
    """Build, check and save a single-node TopK model for the given opset version.

    For opset 1 the number of elements ``k`` is passed as a node attribute; for any
    later opset it is passed as a second graph input named ``K``. The model is saved
    as ``top_k_opset_{op_set_version}.onnx`` (path relative to the working directory).
    """
    input_tensors = [helper.make_tensor_value_info('X', TensorProto.FLOAT, X.shape)]

    # The ONNX TopK specification constrains the Indices output to int64.
    output_tensors = [
        helper.make_tensor_value_info('Values', TensorProto.FLOAT, new_dims),
        helper.make_tensor_value_info('Indices', TensorProto.INT64, new_dims),
    ]

    # Create the TopK node
    if op_set_version > 1:
        node = helper.make_node(
            'TopK',
            inputs=['X', 'K'],
            outputs=['Values', 'Indices'],
            axis=axis,  # Axis along which to find the top K elements
        )
        # K must be an int64 tensor per the ONNX spec; this also matches the dtype of
        # the module-level `K` array and of the initializer sketched below.
        input_tensors.append(helper.make_tensor_value_info('K', TensorProto.INT64, K.shape))
    else:
        node = helper.make_node(
            'TopK',
            inputs=['X'],
            outputs=['Values', 'Indices'],
            axis=axis,  # Axis along which to find the top K elements
            k=k,
        )

    # Create the graph
    graph = helper.make_graph(
        nodes=[node],
        name='TopKGraph',
        inputs=input_tensors,
        outputs=output_tensors,
        # Uncomment when initializers are supported. Currently we can't test opset 10/11
        # since the code will require a k value to be initialized for testing.
        # initializer=[
        #     helper.make_tensor('X', TensorProto.FLOAT, X.shape, X),
        #     helper.make_tensor('K', TensorProto.INT64, [1], [k]),
        # ]
    )

    # Create the model
    model = helper.make_model(
        graph,
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", op_set_version)],
    )
    # Check the model
    onnx.checker.check_model(model)

    # Save the model to a file
    onnx.save(model, f'top_k_opset_{op_set_version}.onnx')
    print(f"Model saved to top_k_opset_{op_set_version}.onnx")
|
||
def main():
    """Generate one TopK test model for each opset version in which the op changed."""
    # Opsets 10 and 11 are generated as well, although (per the note in create_model)
    # they cannot be tested until initializers are supported.
    for version in (1, 10, 11):
        create_model(version)


if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
use super::{Node, NodeCodegen}; | ||
use crate::burn::{Scope, TensorType, Type}; | ||
use burn::config::Config; | ||
use burn::record::PrecisionSettings; | ||
use proc_macro2::TokenStream; | ||
use quote::{quote, ToTokens}; | ||
|
||
#[derive(Config, Debug)] | ||
pub struct TopKConfig { | ||
pub axis: usize, | ||
pub k: usize, | ||
pub largest: usize, | ||
} | ||
|
||
#[derive(Debug, Clone, new)] | ||
pub struct TopKNode { | ||
pub input: TensorType, | ||
pub outputs: Vec<TensorType>, | ||
pub config: TopKConfig, | ||
} | ||
|
||
impl<PS: PrecisionSettings> NodeCodegen<PS> for TopKNode { | ||
fn output_types(&self) -> Vec<Type> { | ||
self.outputs | ||
.iter() | ||
.map(|t| Type::Tensor(t.clone())) | ||
.collect() | ||
} | ||
|
||
fn input_types(&self) -> Vec<Type> { | ||
vec![Type::Tensor(self.input.clone())] | ||
} | ||
|
||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { | ||
let axis = self.config.axis.to_token_stream(); | ||
let k = self.config.k.to_token_stream(); | ||
let largest = self.config.largest.to_token_stream(); | ||
|
||
let input = scope.tensor_use_owned(&self.input, node_position); | ||
let values_output = &self.outputs[0].name; | ||
let indices_output = &self.outputs[1].name; | ||
|
||
quote! { | ||
let (#values_output, #indices_output) = #input.topk_with_indices(#k, #axis, #largest); | ||
} | ||
} | ||
|
||
fn into_node(self) -> Node<PS> { | ||
Node::TopK(self) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use burn::record::FullPrecisionSettings; | ||
|
||
use super::*; | ||
use crate::burn::{ | ||
graph::BurnGraph, | ||
node::{test::assert_tokens, top_k::TopKNode}, | ||
TensorType, | ||
}; | ||
|
||
#[test] | ||
fn test_codegen_nodes() { | ||
let mut graph = BurnGraph::<FullPrecisionSettings>::default(); | ||
let config = TopKConfig::new(1, 3, 1); | ||
|
||
graph.register(TopKNode::new( | ||
TensorType::new_float("input_tensor", 4), | ||
vec![ | ||
TensorType::new_float("values_tensor", 4), | ||
TensorType::new_int("indices_tensor", 4), | ||
], | ||
config, | ||
)); | ||
|
||
graph.register_input_output( | ||
vec!["input_tensor".to_string()], | ||
vec!["values_tensor".to_string(), "indices_tensor".to_string()], | ||
); | ||
|
||
let expected = quote! { | ||
use burn::tensor::Int; | ||
use burn::{ | ||
module::Module, | ||
tensor::{backend::Backend, Tensor}, | ||
}; | ||
|
||
#[derive(Module, Debug)] | ||
pub struct Model<B: Backend> { | ||
phantom: core::marker::PhantomData<B>, | ||
device: burn::module::Ignored<B::Device>, | ||
} | ||
|
||
impl<B: Backend> Model <B> { | ||
#[allow(unused_variables)] | ||
pub fn new(device: &B::Device) -> Self { | ||
Self { | ||
phantom: core::marker::PhantomData, | ||
device: burn::module::Ignored(device.clone()), | ||
} | ||
} | ||
#[allow(clippy::let_and_return, clippy::approx_constant)] | ||
pub fn forward(&self, input_tensor: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4, Int>) { | ||
let (values_tensor, indices_tensor) = input_tensor.topk_with_indices(3usize, 1usize, 1usize); | ||
(values_tensor, indices_tensor) | ||
} | ||
} | ||
}; | ||
|
||
assert_tokens(graph.codegen(), expected); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,9 +7,16 @@ use burn::nn::{ | |
PaddingConfig2d, PaddingConfig3d, | ||
}; | ||
|
||
use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig}; | ||
use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig, top_k::TopKConfig}; | ||
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node}; | ||
|
||
/// Extract and convert a given attribute to i64 | ||
fn extract_attr_value_i64(node: &Node, key: &str) -> i64 { | ||
let error_msg = format!("Expected the following attribute key: {:?}", key); | ||
let value = node.attrs.get(key).expect(&error_msg).clone().into_i64(); | ||
value | ||
} | ||
|
||
/// Create a Conv1dConfig from the attributes of the node | ||
pub fn conv1d_config(curr: &Node) -> Conv1dConfig { | ||
let mut kernel_shape = Vec::new(); // TODO default inferred from weight tensor per spec | ||
|
@@ -795,6 +802,38 @@ pub fn tile_config(node: &Node) -> TileConfig { | |
TileConfig::new(repeat) | ||
} | ||
|
||
/// Create a TopKConfig from the attributes of the node. | ||
pub fn top_k_config(node: &Node) -> TopKConfig { | ||
// extract the shape of the input data tensor | ||
let data_tensor = match node.inputs.first().unwrap().clone().ty { | ||
ArgType::Tensor(tensor) => tensor, | ||
_ => panic!("Only tensor input is valid"), | ||
}; | ||
|
||
let k = match node.inputs.get(1) { | ||
Some(k_tensor) => k_tensor | ||
.clone() | ||
.value | ||
.expect("Expecting K tensor to have a value.") | ||
.into_i64s()[0], | ||
_ => extract_attr_value_i64(node, "k"), | ||
}; | ||
|
||
let mut axis: i64 = extract_attr_value_i64(node, "axis"); | ||
|
||
// if axis is negative, it is counted from the end | ||
if axis < 0 { | ||
axis += data_tensor.dim as i64; | ||
} | ||
|
||
let largest = match node.attrs.get("largest") { | ||
Some(val) => val.clone().into_i64(), | ||
_ => 1, | ||
}; | ||
Comment on lines
+829
to
+832
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you're not checking for "k" as the second input of the node (for opsets 10, 11) and just adding support for opset 1, then we don't need to check for the "largest" attribute here. It's only present in the later version 11 of the op. So we can remove this from the config and node. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure thing will remove and resubmit tonight |
||
|
||
TopKConfig::new(axis as usize, k as usize, largest as usize) | ||
} | ||
|
||
/// Create a PadConfig from the attributes of the node | ||
pub fn pad_config(node: &Node) -> PadConfig { | ||
fn get_pads_input(node: &Node) -> Vec<i64> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the CI caught something I missed! This file doesn't exist anymore with your changes :)