Commit
fix: update the to_onnx script so it can correctly export the llama model
bitzyz committed Mar 12, 2024
1 parent bf5402e commit 93dc8d9
Showing 6 changed files with 77 additions and 8 deletions.
66 changes: 65 additions & 1 deletion scripts/onnx/to_onnx.py
@@ -1,4 +1,5 @@
import mmap
import re
import argparse
from onnx import TensorProto, NodeProto, save_model
from onnx.helper import (
@@ -139,7 +140,9 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                ),
                [],
            )
-        if self.type == "Add":
+        if self.type in ["Add", "Pow", "Sqrt", "Div",
+                         "Mul", "Sub", "Exp", "Log",
+                         "Neg", "Sigmoid"]:
            return (
                make_node(
                    self.type,
@@ -203,6 +206,58 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]:
                ),
                [shape],
            )
        if self.type in ["Gather", "Concat", "Softmax"]:
            # the axis is serialized as the first "/"-separated field of meta
            meta = self.meta.split(b"/")
            axis = int(meta[0])
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[self.topo.outputs[0]].name],
                    self.name,
                    axis=axis,
                ),
                [],
            )
        if self.type == "ReduceMean":
            # meta is comma-separated: axes array, rank, keepdims flag
            meta = self.meta.split(b",")
            keepDims = meta[2] == b"true"
            axes = [int(x) for x in split_array(meta[0])]
            return (
                make_node(
                    self.type,
                    [tensors[self.topo.inputs[0]].name],
                    [tensors[self.topo.outputs[0]].name],
                    self.name,
                    axes=axes,
                    keepdims=keepDims,
                ),
                [],
            )
        if self.type == "Transpose":
            # meta holds the permutation, e.g. b"[ 0 2 1 ]"
            meta = [int(x) for x in split_array(self.meta)]
            return (
                make_node(
                    self.type,
                    [tensors[self.topo.inputs[0]].name],
                    [tensors[self.topo.outputs[0]].name],
                    self.name,
                    perm=meta,
                ),
                [],
            )
        if self.type == "Slice":
            # Slice (opset >= 10) takes starts/ends/axes/steps as inputs,
            # so no attributes are attached here
            # starts, ends, axes, steps = split_array_slice(self.meta)
            return (
                make_node(
                    self.type,
                    [tensors[i].name for i in self.topo.inputs],
                    [tensors[self.topo.outputs[0]].name],
                    self.name,
                ),
                [],
            )

        raise ValueError(f"Unsupported operator {self.type}")

def parse_args():
@@ -221,6 +276,15 @@ def parse_args():
def split_array(arr: bytes):
    return (x for x in arr.strip().strip(b"[").strip(b"]").split())

def split_array_slice(arr: bytes):
    # each dimension of the Slice meta is parsed as three integers m;
    # the slice is rebuilt as start = m[0], end = m[0] + m[1] * m[2], step = m[2]
    meta_array = split_array(arr)
    meta = [list(map(int, re.findall(rb"\d+", x))) for x in meta_array]
    starts = [m[0] for m in meta]
    ends = [m[0] + m[1] * m[2] for m in meta]
    axes = list(range(len(meta)))
    steps = [m[2] for m in meta]
    return starts, ends, axes, steps

def main():
    path = parse_args()
    info_path = path + "/graph.info"
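For reference, here is a standalone sanity check (not part of the commit) that builds nodes with the same attribute names the new branches emit — perm, axes/keepdims, and axis — and runs them through onnx.checker. Tensor names, shapes, and the opset pin are illustrative assumptions; opset 13 is chosen so that ReduceMean still takes axes as an attribute rather than an input.

import onnx
from onnx import TensorProto
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_opsetid,
    make_tensor_value_info,
)

# x: [2, 3, 4] -> Transpose(perm=[1, 2, 0]) -> [3, 4, 2]
#              -> ReduceMean(axes=[1], keepdims=1) -> [3, 1, 2]
#              -> Softmax(axis=-1) -> [3, 1, 2]
x = make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])
y = make_tensor_value_info("y", TensorProto.FLOAT, [3, 1, 2])
nodes = [
    make_node("Transpose", ["x"], ["t"], "transpose_0", perm=[1, 2, 0]),
    make_node("ReduceMean", ["t"], ["r"], "reduce_0", axes=[1], keepdims=1),
    make_node("Softmax", ["r"], ["y"], "softmax_0", axis=-1),
]
model = make_model(
    make_graph(nodes, "attr_check", [x], [y]),
    opset_imports=[make_opsetid("", 13)],
)
onnx.checker.check_model(model)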
4 changes: 2 additions & 2 deletions src/05computation/src/graph.cc
@@ -220,8 +220,8 @@ namespace refactor::computation {
void Graph::optimize() {
auto graphMutant = GraphMutant(*this);
std::vector<std::string_view> passes = {
"MatMulTransposeFuse",
"ConvToMatmul",
// "MatMulTransposeFuse",
// "ConvToMatmul",
};
register_();//all pass insert
auto g = std::make_shared<GraphMutant>(graphMutant);
2 changes: 1 addition & 1 deletion src/05computation/src/operators/reduce.cc
@@ -116,7 +116,7 @@ namespace refactor::computation {
return std::make_unique<kernel::ReduceCollector>(target, type, axes);
}
auto Op::serialize() const noexcept -> std::string {
- return fmt::format("{}({}/{}, {})",
+ return fmt::format("{}({}, {}, {})",
name(),
vec2str(axes),
rank,
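The serialize() tweak above switches ReduceMean's metadata to comma-separated fields, which matches what the new ReduceMean branch in to_onnx.py splits on. A minimal sketch of that parse, assuming the Python side receives an argument list shaped like b"[ 0 1 ], 4, true" (the exact whitespace handling is an assumption):

meta = b"[ 0 1 ], 4, true"      # axes, rank, keepdims flag
fields = [f.strip() for f in meta.split(b",")]
axes = [int(x) for x in fields[0].strip(b"[").strip(b"]").split()]  # -> [0, 1]
keepdims = fields[2] == b"true"                                     # -> True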
9 changes: 7 additions & 2 deletions src/07onnx/src/operators/cast.cc
@@ -1,6 +1,7 @@
#include "computation/operators/cast.h"
#include "cast.hh"
#include "common.h"
#include "computation/operators/identity.h"
#include <execution>

namespace refactor::onnx {
@@ -30,7 +31,6 @@ namespace refactor::onnx {

auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
EXPECT_SIZE(1)

auto const &input = inputs[0];
auto ans = Tensor::share(to, input.shape, extractDependency(inputs));
if (!options.shouldCalculate(inputs, {*ans})) {
@@ -116,8 +116,13 @@ namespace refactor::onnx {
}
return Ok(Tensors{std::move(ans)});
}
- auto Op::lower(TensorRefs) const -> computation::OpBox {
+ auto Op::lower(TensorRefs inputs) const -> computation::OpBox {
using Op_ = computation::Cast;
auto const &input = inputs[0];
auto from = input.dataType;
if (from == to) {
return std::make_unique<computation::Identity>();
}
return std::make_unique<Op_>();
}

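The new guard lowers a Cast whose source and target dtypes already match to Identity instead of a real cast kernel. A trivial numpy illustration (not the project's code) of why that cast is a no-op:

import numpy as np

x = np.arange(6, dtype=np.float32)
assert np.array_equal(x.astype(np.float32), x)  # casting to the same dtype changes nothing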
2 changes: 1 addition & 1 deletion src/07onnx/src/operators/mat_mul.cc
@@ -63,7 +63,7 @@ namespace refactor::onnx {

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::MatMul;
- return std::make_unique<Op_>(1.0, 1.0, false, false);
+ return std::make_unique<Op_>(1.0, 0.0, false, false);
}

}// namespace refactor::onnx
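Assuming the MatMul lowering's constructor arguments are (alpha, beta, transA, transB) in the Gemm sense, the change sets beta to 0 because plain ONNX MatMul has no bias term C to accumulate. A small numpy illustration:

import numpy as np

A = np.random.rand(2, 3)
B = np.random.rand(3, 4)
C = np.random.rand(2, 4)
alpha, beta = 1.0, 0.0
Y = alpha * (A @ B) + beta * C   # Gemm-style formula
assert np.allclose(Y, A @ B)     # beta = 0 removes the stray C term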
2 changes: 1 addition & 1 deletion src/09python_ffi/src/compiler.cc
@@ -155,7 +155,7 @@ namespace refactor::python_ffi {
msg += ']';
RUNTIME_ERROR(std::move(msg));
}
- _g.fillEdgeInfo(false);
+ _g.fillEdgeInfo(true);

namespace fs = std::filesystem;
auto path = fs::path(std::move(path_));
