
Commit 38f4193
update reduce_logsumexp
Zheng-Bicheng committed May 16, 2024
1 parent b8c1341 commit 38f4193
Showing 8 changed files with 53 additions and 367 deletions.
57 changes: 44 additions & 13 deletions paddle2onnx/mapper/tensor/reduce_logsumexp.cc
@@ -23,43 +23,74 @@ int32_t ReduceLogSumExpMapper::GetMinOpset(bool verbose) {
return op_version;
}

void ReduceLogSumExpMapper::Opset7() {
void ReduceLogSumExpMapper::Opset18() {
GetAttr("keepdim", &keep_dim_);
GetAttr("reduce_all", &reduce_all_);
GetAttr("axis", &dim_);

auto x_info = GetInput("X");
auto out_info = GetOutput("Out");
std::string axis_name = "axis";
if (IsAttrVar(axis_name)) {
auto info = GetAttrVar(axis_name);
TryGetValue(info[0], &dim_);
std::string dims;
if (!reduce_all_) {
dims = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, dim_);
} else {
GetAttr(axis_name, &dim_);
dims = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, Arange(0, x_info[0].Rank()));
}

std::string input_name = x_info[0].name;
auto input_tpye = x_info[0].dtype;
if (x_info[0].dtype == P2ODataType::BOOL) {
input_name = helper_->AutoCast(input_name, input_tpye, P2ODataType::INT32);
input_tpye = P2ODataType::INT32;
}
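// Note: since ONNX opset 18, ReduceLogSumExp takes the reduction axes as a
// second input tensor rather than as an attribute, hence the "dims" input below.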
auto reduce_node = helper_->MakeNode("ReduceLogSumExp", {input_name, dims});

// Add attribute
AddAttribute(reduce_node, "keepdims", static_cast<int64_t>(keep_dim_));
auto out_node_name = reduce_node->output(0);

bool reduce_all_axes = dim_.size() == x_info[0].Rank();
if (reduce_all_) {
reduce_all_axes = true;
}
if (!keep_dim_ && reduce_all_axes) {
out_node_name = helper_->Reshape(out_node_name, {-1});
}
auto out_info = GetOutput("Out");
helper_->AutoCast(out_node_name, out_info[0].name, input_tpye, out_info[0].dtype);
}

void ReduceLogSumExpMapper::Opset11() {
GetAttr("keepdim", &keep_dim_);
GetAttr("reduce_all", &reduce_all_);
GetAttr("axis", &dim_);

auto x_info = GetInput("X");
auto out_info = GetOutput("Out");
std::string input_name = x_info[0].name;
if (OpType() == "reduce_prod" && x_info[0].dtype == P2ODataType::FP64) {
input_name = helper_->AutoCast(x_info[0].name, P2ODataType::FP64, P2ODataType::FP32);
auto input_tpye = x_info[0].dtype;
if (x_info[0].dtype == P2ODataType::BOOL) {
input_name = helper_->AutoCast(input_name, input_tpye, P2ODataType::INT32);
input_tpye = P2ODataType::INT32;
}
auto reduce_node = helper_->MakeNode("ReduceLogSumExp", {input_name});


// Add attribute
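// In opsets 11-17, ReduceLogSumExp takes the reduction axes via the "axes" attribute instead of an input.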
if (!reduce_all_) {
AddAttribute(reduce_node, "axes", dim_);
} else {
AddAttribute(reduce_node, "axes", Arange(0, x_info[0].Rank()));
}
AddAttribute(reduce_node, "keepdims", static_cast<int64_t>(keep_dim_));

auto out = reduce_node->output(0);
if (OpType() == "reduce_prod" && x_info[0].dtype == P2ODataType::FP64) {
out = helper_->AutoCast(reduce_node->output(0), P2ODataType::FP32, P2ODataType::FP64);
bool reduce_all_axes = dim_.size() == x_info[0].Rank();
if (reduce_all_) {
reduce_all_axes = true;
}
if (!keep_dim_ && reduce_all_axes) {
out = helper_->Reshape(out, {-1});
}
helper_->AutoCast(out, out_info[0].name, x_info[0].dtype, out_info[0].dtype);
helper_->AutoCast(out, out_info[0].name, input_tpye, out_info[0].dtype);
}

} // namespace paddle2onnx
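
For context, the main difference the two mappers above handle is where ONNX expects the reduction axes. The sketch below is illustrative only (not part of this commit); it builds the equivalent ReduceLogSumExp nodes with onnx.helper, and the tensor names "x", "axes", and "y" are placeholders.

from onnx import helper, TensorProto

# Opsets 11-17: axes is an attribute on the node, as in Opset11().
node_v11 = helper.make_node(
    "ReduceLogSumExp", inputs=["x"], outputs=["y"], axes=[0], keepdims=1)

# Opset 18 and later: axes is an optional second input tensor, as in Opset18().
axes_init = helper.make_tensor("axes", TensorProto.INT64, dims=[1], vals=[0])
node_v18 = helper.make_node(
    "ReduceLogSumExp", inputs=["x", "axes"], outputs=["y"], keepdims=1)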
13 changes: 3 additions & 10 deletions paddle2onnx/mapper/tensor/reduce_logsumexp.h
@@ -25,17 +25,10 @@ class ReduceLogSumExpMapper : public Mapper {
ReduceLogSumExpMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {
if (OpType() == "logsumexp") {
GetAttr("keepdim", &keep_dim_);
GetAttr("reduce_all", &reduce_all_);
} else {
GetAttr("keep_dim", &keep_dim_);
GetAttr("reduce_all", &reduce_all_);
GetAttr("in_dtype", &in_dtype_);
GetAttr("out_dtype", &out_dtype_);
}
}
void Opset7();

void Opset18() override;
void Opset11();

int32_t GetMinOpset(bool verbose = false);

114 changes: 0 additions & 114 deletions tests/test_all.py

This file was deleted.

110 changes: 0 additions & 110 deletions tests/test_any.py

This file was deleted.

1 change: 1 addition & 0 deletions tests/test_auto_scan_layer_norm.py
@@ -68,6 +68,7 @@ def sample_convert_config(self, draw):
input_shape[4] = 10
axis = draw(st.integers(min_value=1, max_value=len(input_shape) - 1))

# axis_type = draw(st.sampled_from(["int", "list"]))
axis_type = draw(st.sampled_from(["int", "list"]))
if axis_type == "int":
normalized_shape = input_shape[-1]
2 changes: 1 addition & 1 deletion tests/test_auto_scan_logsumexp.py
@@ -56,7 +56,7 @@ def sample_convert_config(self, draw):
"op_names": ["logsumexp"],
"test_data_shapes": [input_shape],
"test_data_types": [[dtype]],
"opset_version": [7, 9, 15],
"opset_version": [11, 13, 18],
"input_spec_shape": [],
"axis": axis,
"keepdim": keepdim,
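For reference, a minimal sketch (not part of the test file) of the Paddle op this config exercises, with a NumPy reference for the same reduction; the shapes and values are placeholders.

import numpy as np
import paddle

# paddle.logsumexp computes log(sum(exp(x))) over the given axis.
x = paddle.to_tensor(np.random.rand(2, 3, 4).astype("float32"))
y = paddle.logsumexp(x, axis=[1], keepdim=True)  # shape [2, 1, 4]

# NumPy reference for the same reduction.
ref = np.log(np.exp(x.numpy()).sum(axis=1, keepdims=True))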
8 changes: 4 additions & 4 deletions tests/test_auto_scan_reduce_ops.py
@@ -21,11 +21,11 @@
import random

op_api_map = {
# "reduce_max": paddle.max,
"reduce_max": paddle.max,
"reduce_min": paddle.min,
# "reduce_mean": paddle.mean,
# "reduce_sum": paddle.sum,
# "reduce_prod": paddle.prod,
"reduce_mean": paddle.mean,
"reduce_sum": paddle.sum,
"reduce_prod": paddle.prod,
}

opset_version_map = {
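For reference, the op names re-enabled in op_api_map above correspond to the following Paddle reduction APIs; a minimal usage sketch (illustrative only, not part of the test file):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
print(paddle.max(x, axis=0))   # reduce_max
print(paddle.min(x, axis=0))   # reduce_min
print(paddle.mean(x, axis=0))  # reduce_mean
print(paddle.sum(x, axis=0))   # reduce_sum
print(paddle.prod(x, axis=0))  # reduce_prod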