diff --git a/include/nncase/compiler.h b/include/nncase/compiler.h index 67485fa5f..ad30e6f2e 100644 --- a/include/nncase/compiler.h +++ b/include/nncase/compiler.h @@ -108,6 +108,7 @@ class NNCASE_API compiler virtual void import_tflite(std::span model, const import_options &options) = 0; virtual void import_onnx(std::span model, const import_options &options) = 0; virtual void import_caffe(std::span model, std::span prototxt) = 0; + virtual void import_pnnx(std::string parampath, std::string binpath, const import_options &options) = 0; virtual void use_ptq(ptq_dataset_options options) = 0; virtual void use_ptq(ptq_tensor_options options) = 0; virtual void dump_range_options(dump_range_dataset_options options) = 0; diff --git a/include/nncase/importer/importer.h b/include/nncase/importer/importer.h index 92f4d7b1e..8c51a9e54 100644 --- a/include/nncase/importer/importer.h +++ b/include/nncase/importer/importer.h @@ -29,4 +29,5 @@ struct import_options void import_tflite(ir::graph &graph, std::span model, const import_options &options, std::string &real_inlayout, std::string &real_outlayout); void import_onnx(ir::graph &graph, std::span model, const import_options &options, std::string &real_inlayout, std::string &real_outlayout); void import_caffe(ir::graph &graph, std::span model, std::span prototxt, std::string &real_inlayout, std::string &real_outlayout); +void import_pnnx(ir::graph &graph, std::string parampath, std::string binpath, const import_options &options, std::string &real_inlayout, std::string &real_outlayout); } diff --git a/python/nncase/native/ffi.cpp b/python/nncase/native/ffi.cpp index c4277a071..a720cbe9d 100644 --- a/python/nncase/native/ffi.cpp +++ b/python/nncase/native/ffi.cpp @@ -196,6 +196,7 @@ PYBIND11_MODULE(_nncase, m) .def("import_tflite", &compiler::import_tflite) .def("import_onnx", &compiler::import_onnx) .def("import_caffe", &compiler::import_caffe) + .def("import_pnnx", &compiler::import_pnnx) .def("compile", &compiler::compile) 
.def("use_ptq", py::overload_cast(&compiler::use_ptq)) .def("dump_range_options", py::overload_cast(&compiler::dump_range_options)) diff --git a/src/cli/compile.cpp b/src/cli/compile.cpp index 294c56d4f..a1249cc2a 100644 --- a/src/cli/compile.cpp +++ b/src/cli/compile.cpp @@ -22,7 +22,7 @@ using namespace nncase::cli; compile_command::compile_command(lyra::cli &cli) { cli.add_argument(lyra::command("compile", [this](const lyra::group &) { this->run(); }) - .add_argument(lyra::opt(input_format_, "input format").name("-i").name("--input-format").required().help("input format, e.g. tflite|onnx|caffe")) + .add_argument(lyra::opt(input_format_, "input format").name("-i").name("--input-format").required().help("input format, e.g. tflite|onnx|caffe|pnnx")) .add_argument(lyra::opt(target_name_, "target").name("-t").name("--target").required().help("target architecture, e.g. cpu|k210|k510")) .add_argument(lyra::arg(input_filename_, "input file").required().help("input file")) .add_argument(lyra::opt(input_prototxt_, "input prototxt").name("--input-prototxt").optional().help("input prototxt")) @@ -153,6 +153,12 @@ void compile_command::run() auto input_prototxt = read_file(input_prototxt_); compiler->import_caffe(file_data, input_prototxt); } + else if (input_format_ == "pnnx") + { + std::filesystem::path input_bin_filename_ = input_filename_; + input_bin_filename_.replace_extension("bin"); + compiler->import_pnnx(input_filename_, input_bin_filename_.string(), i_options); + } else { throw std::invalid_argument("Invalid input format: " + input_format_); diff --git a/src/importer/CMakeLists.txt b/src/importer/CMakeLists.txt index 7c0d66338..bbef82c7d 100644 --- a/src/importer/CMakeLists.txt +++ b/src/importer/CMakeLists.txt @@ -3,9 +3,10 @@ add_subdirectory(tflite) add_subdirectory(onnx) add_subdirectory(caffe) +add_subdirectory(pnnx) add_library(importer OBJECT importer.cpp) target_include_directories(importer PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include) 
-target_link_libraries(importer PUBLIC tflite_importer onnx_proto onnx_importer caffe_importer) -set_target_properties(importer PROPERTIES POSITION_INDEPENDENT_CODE ON) \ No newline at end of file +target_link_libraries(importer PUBLIC tflite_importer onnx_proto onnx_importer caffe_importer pnnx_importer) +set_target_properties(importer PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/src/importer/importer.cpp b/src/importer/importer.cpp index 9b5c619f5..a3a4de346 100644 --- a/src/importer/importer.cpp +++ b/src/importer/importer.cpp @@ -14,6 +14,7 @@ */ #include "caffe/caffe_importer.h" #include "onnx/onnx_importer.h" +#include "pnnx/pnnx_importer.h" #include "tflite/tflite_importer.h" #include @@ -34,4 +35,9 @@ void nncase::importer::import_onnx(ir::graph &graph, std::span mo void nncase::importer::import_caffe(ir::graph &graph, std::span model, std::span prototxt, std::string &real_inlayout, std::string &real_outlayout) { caffe_importer(model, prototxt, graph).import(real_inlayout, real_outlayout); -} \ No newline at end of file +} + +void nncase::importer::import_pnnx(ir::graph &graph, std::string parampath, std::string binpath, const import_options &options, std::string &real_inlayout, std::string &real_outlayout) +{ + pnnx_importer(parampath, binpath, graph).import(options, real_inlayout, real_outlayout); +} diff --git a/src/importer/pnnx/CMakeLists.txt b/src/importer/pnnx/CMakeLists.txt new file mode 100644 index 000000000..114c47a3e --- /dev/null +++ b/src/importer/pnnx/CMakeLists.txt @@ -0,0 +1,26 @@ + include(TestBigEndian) +test_big_endian(BIG_ENDIAN) + +set(PNNX_IMPORTER_SOURCES + pnnx_importer.cpp + ir.cpp + storezip.cpp + ) + +set(PNNX_IMPORTER_OPS_SOURCES + ops/input.cpp + ops/output.cpp + ops/relu.cpp + ops/relu6.cpp + ops/conv2d.cpp + ) + +add_library(pnnx_importer ${PNNX_IMPORTER_SOURCES} ${PNNX_IMPORTER_OPS_SOURCES}) + +target_compile_definitions(pnnx_importer PRIVATE NATIVE_IS_BIG_ENDIAN=${BIG_ENDIAN}) 
+target_include_directories(pnnx_importer PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) +get_filename_component(PARENT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) +target_include_directories(pnnx_importer PUBLIC ${PARENT_SOURCE_DIR}/include) + +target_link_libraries(pnnx_importer PUBLIC ir) +set_target_properties(pnnx_importer PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/src/importer/pnnx/ir.cpp b/src/importer/pnnx/ir.cpp new file mode 100644 index 000000000..263d81d53 --- /dev/null +++ b/src/importer/pnnx/ir.cpp @@ -0,0 +1,1767 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "ir.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "storezip.h" + +namespace pnnx +{ + +static const char *type_to_string(int type) +{ + if (type == 1) + return "f32"; + if (type == 2) + return "f64"; + if (type == 3) + return "f16"; + if (type == 4) + return "i32"; + if (type == 5) + return "i64"; + if (type == 6) + return "i16"; + if (type == 7) + return "i8"; + if (type == 8) + return "u8"; + return "null"; +} + +static const char *type_to_numpy_string(int type) +{ + if (type == 1) + return "float32"; + if (type == 2) + return "float64"; + if (type == 3) + return "float16"; + if (type == 4) + return "int32"; + if (type == 5) + return "int64"; + if (type == 6) + return "int16"; + if (type == 7) + return "int8"; + if (type == 8) + return "uint8"; + return "null"; +} + +static size_t type_to_elemsize(int type) +{ + if (type == 1) + return 4; + if (type == 2) + return 8; + if (type == 3) + return 2; + if (type == 4) + return 4; + if (type == 5) + return 8; + if (type == 6) + return 2; + if (type == 7) + return 1; + if (type == 8) + return 1; + return 0; // null +} + +static int string_to_type(const char *s) +{ + if (strcmp(s, "f32") == 0) + return 1; + if (strcmp(s, "f64") == 0) + return 2; + if (strcmp(s, "f16") == 0) + return 3; + if (strcmp(s, "i32") == 0) + return 4; + if (strcmp(s, "i64") == 0) + return 5; + if (strcmp(s, "i16") == 0) + return 6; + if (strcmp(s, "i8") == 0) + return 7; + if (strcmp(s, "u8") == 0) + return 8; + return 0; // null +} + +Attribute::Attribute(const std::initializer_list &_shape, const std::vector &t) +{ + type = 1; + shape = _shape; + + if (shape.size() > 0) + { + int size = shape[0]; + for (size_t i = 1; i < shape.size(); i++) + { + size *= shape[i]; + } + + data.resize(size * type_to_elemsize(type)); + memcpy((void *)data.data(), (const void *)t.data(), data.size()); + } +} + +Parameter Parameter::parse_from_string(const std::string &value) +{ + Parameter p; + p.type = 0; + + if 
(value == "None" || value == "()" || value == "[]") + { + return p; + } + + if (value == "True" || value == "False") + { + // bool + p.type = 1; + p.b = value == "True"; + return p; + } + + if (value[0] == '(' || value[0] == '[') + { + // list + std::string lc = value.substr(1, value.size() - 2); + std::istringstream lcss(lc); + + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) + { + // string + p.type = 7; + p.as.push_back(elem); + } + else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) + { + // float + p.type = 6; + p.af.push_back(std::stof(elem)); + } + else + { + // integer + p.type = 5; + p.ai.push_back(std::stoi(elem)); + } + } + return p; + } + + if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9'))) + { + // string + p.type = 4; + p.s = value; + return p; + } + + if (value.find('.') != std::string::npos || value.find('e') != std::string::npos) + { + // float + p.type = 3; + p.f = std::stof(value); + return p; + } + + // integer + p.type = 2; + p.i = std::stoi(value); + return p; +} + +Graph::Graph() +{ +} + +Graph::~Graph() +{ + for (auto x : ops) + delete x; + + for (auto x : operands) + delete x; + + ops.clear(); + operands.clear(); +} + +Graph::Graph(const Graph & /*rhs*/) +{ +} + +Graph &Graph::operator=(const Graph & /*rhs*/) +{ + return *this; +} + +static void load_parameter(Operator *op, const std::string &key, const std::string &value) +{ + op->params[key] = Parameter::parse_from_string(value); +} + +static void load_input_key(Operator *op, const std::string &key, const std::string &value) +{ + op->inputnames.resize(op->inputs.size()); + + for (size_t i = 0; i < op->inputs.size(); i++) + { + const Operand *oprand = op->inputs[i]; + if (oprand->name == value) + { + op->inputnames[i] = key; + break; + } + } +} 
+ +static void load_shape(Operator *op, const std::string &key, const std::string &value) +{ + Operand *operand = 0; + for (auto r : op->inputs) + { + if (r->name == key) + { + operand = r; + break; + } + } + + if (!operand) + { + for (auto r : op->outputs) + { + if (r->name == key) + { + operand = r; + break; + } + } + } + + if (!operand) + { + fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str()); + return; + } + + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + operand->type = string_to_type(typestr.c_str()); + + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); + + operand->shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + if (elem == "?") + { + operand->shape.push_back(-1); + } + else + { + int i = std::stoi(elem); + operand->shape.push_back(i); + } + } +} + +static void load_attribute(Operator *op, const std::string &key, const std::string &value, StoreZipReader &szr) +{ + Attribute &a = op->attrs[key]; + + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + a.type = string_to_type(typestr.c_str()); + + if (a.type == 0) + return; + + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); + + a.shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + int i = std::stoi(elem); + a.shape.push_back(i); + } + + if (a.shape.empty()) + return; + + // data + size_t size = 1; + for (int i : a.shape) + { + size *= i; + } + + size_t bytesize = size * type_to_elemsize(a.type); + + std::string filename = op->name + "." 
+ key; + + size_t filesize = szr.get_file_size(filename); + + if (filesize == 0) + { + // no such file + return; + } + + if (filesize != bytesize) + { + fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize); + } + + a.data.resize(bytesize); + szr.read_file(filename, (char *)a.data.data()); +} + +int Graph::load(const std::string ¶mpath, const std::string &binpath) +{ + std::ifstream is(parampath, std::ios::in | std::ios::binary); + if (!is.good()) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + StoreZipReader szr; + if (szr.open(binpath) != 0) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + int magic = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> magic; + } + + int operator_count = 0; + int operand_count = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> operator_count >> operand_count; + } + + for (int i = 0; i < operator_count; i++) + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + std::string type; + std::string name; + int input_count = 0; + int output_count = 0; + + iss >> type >> name >> input_count >> output_count; + + Operator *op = new_operator(type, name); + + for (int j = 0; j < input_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand *r = get_operand(operand_name); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int j = 0; j < output_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand *r = new_operand(operand_name); + r->producer = op; + op->outputs.push_back(r); + } + + // key=value + while (!iss.eof()) + { + std::string param; + iss >> param; + + std::string key; + std::string value; + std::istringstream pss(param); + std::getline(pss, key, '='); + std::getline(pss, value); + + if (key[0] == '@') + { + // attribute + load_attribute(op, key.substr(1), value, szr); + } + else if (key[0] == 
'$') + { + // operand input key + load_input_key(op, key.substr(1), value); + } + else if (key[0] == '#') + { + // operand shape + load_shape(op, key.substr(1), value); + } + else + { + // parameter + load_parameter(op, key, value); + } + } + } + + return 0; +} + +int Graph::save(const std::string ¶mpath, const std::string &binpath) +{ + FILE *paramfp = fopen(parampath.c_str(), "wb"); + if (!paramfp) + { + fprintf(stderr, "fopen %s failed\n", parampath.c_str()); + return -1; + } + + StoreZipWriter szw; + if (szw.open(binpath) != 0) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + // magic + fprintf(paramfp, "7767517\n"); + + // op count and oprand count + fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); + + for (const Operator *op : ops) + { + fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); + + for (const Operand *oprand : op->inputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const Operand *oprand : op->outputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const auto &it : op->params) + { + fprintf(paramfp, " %s=", it.first.c_str()); + + const Parameter ¶m = it.second; + if (param.type == 0) + { + fprintf(paramfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(paramfp, "True"); + else + fprintf(paramfp, "False"); + } + if (param.type == 2) + { + fprintf(paramfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(paramfp, "%e", param.f); + } + if (param.type == 4) + { + fprintf(paramfp, "%s", param.s.c_str()); + } + if (param.type == 5) + { + fprintf(paramfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(paramfp, "%d", param.ai[i]); + if (i + 1 != param.ai.size()) + fprintf(paramfp, ","); + } + fprintf(paramfp, ")"); + } + if (param.type == 6) + { + fprintf(paramfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(paramfp, "%e", param.af[i]); + if (i + 1 
!= param.af.size()) + fprintf(paramfp, ","); + } + fprintf(paramfp, ")"); + } + if (param.type == 7) + { + fprintf(paramfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + fprintf(paramfp, "%s", param.as[i].c_str()); + if (i + 1 != param.as.size()) + fprintf(paramfp, ","); + } + fprintf(paramfp, ")"); + } + } + + for (const auto &it : op->attrs) + { + fprintf(paramfp, " @%s=", it.first.c_str()); + + const Attribute &attr = it.second; + fprintf(paramfp, "("); + for (int i = 0; i < (int)attr.shape.size() - 1; i++) + { + fprintf(paramfp, "%d,", attr.shape[i]); + } + if (attr.shape.size() > 0) + fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]); + fprintf(paramfp, ")"); + + fprintf(paramfp, "%s", type_to_string(attr.type)); + + std::string filename = op->name + "." + it.first; + szw.write_file(filename, attr.data.data(), attr.data.size()); + } + + if (op->inputnames.size() == op->inputs.size()) + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (op->inputnames[i].empty()) + continue; + + const Operand *oprand = op->inputs[i]; + fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); + } + } + + for (const Operand *oprand : op->inputs) + { + if (oprand->shape.empty()) + continue; + + fprintf(paramfp, " #%s=", oprand->name.c_str()); + + fprintf(paramfp, "("); + for (int i = 0; i < (int)oprand->shape.size() - 1; i++) + { + if (oprand->shape[i] == -1) + fprintf(paramfp, "?,"); + else + fprintf(paramfp, "%d,", oprand->shape[i]); + } + if (oprand->shape.size() > 0) + { + if (oprand->shape[oprand->shape.size() - 1] == -1) + fprintf(paramfp, "?"); + else + fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + } + fprintf(paramfp, ")"); + + fprintf(paramfp, "%s", type_to_string(oprand->type)); + } + + for (const Operand *oprand : op->outputs) + { + if (oprand->shape.empty()) + continue; + + fprintf(paramfp, " #%s=", oprand->name.c_str()); + + fprintf(paramfp, "("); + for (int i = 0; i < (int)oprand->shape.size() - 
1; i++) + { + if (oprand->shape[i] == -1) + fprintf(paramfp, "?,"); + else + fprintf(paramfp, "%d,", oprand->shape[i]); + } + if (oprand->shape.size() > 0) + { + if (oprand->shape[oprand->shape.size() - 1] == -1) + fprintf(paramfp, "?"); + else + fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + } + fprintf(paramfp, ")"); + + fprintf(paramfp, "%s", type_to_string(oprand->type)); + } + + fprintf(paramfp, "\n"); + } + + fclose(paramfp); + + return 0; +} + +static std::string sanitize_identifier(const std::string &s) +{ + std::string ss = s; + for (size_t i = 0; i < ss.size(); i++) + { + if (ss[i] == '.' || ss[i] == ':') + ss[i] = '_'; + } + + return ss; +} + +static std::string expand_expression(const Operator *op) +{ + std::string expr = op->params.at("expr").s; + + // split into tokens + std::vector tokens; + { + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + } + + // scan and stack + std::stack exprstack; + for (int i = (int)tokens.size() - 1; i >= 0; i--) + { + const std::string &t = tokens[i]; + + if (t == "size") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = a + ".size(" + b + ")"; + exprstack.push(r); + } + else if (t == "int" || t == "sqrt" || t == "rsqrt" || t == "neg") + { + std::string unaryop; + if (t == "int") + unaryop = "int"; + if (t == "sqrt") + unaryop = "torch.sqrt"; + if (t == "rsqrt") + unaryop = "torch.rsqrt"; + if (t == "neg") + unaryop = "torch.neg"; + + std::string a = exprstack.top(); + exprstack.pop(); + + std::string r = unaryop + "(" + a + ")"; + exprstack.push(r); + } + else if (t == "pow") + { + std::string a 
= exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = a + ".pow(" + b + ")"; + exprstack.push(r); + } + else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide") + { + std::string binaryop; + if (t == "add") + binaryop = "+"; + if (t == "sub") + binaryop = "-"; + if (t == "mul") + binaryop = "*"; + if (t == "div") + binaryop = "/"; + if (t == "floor_divide") + binaryop = "//"; + + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = std::string("(") + a + " " + binaryop + " " + b + ")"; + exprstack.push(r); + } + else if (t == "[") // list + { + std::vector elements; + while (!exprstack.empty()) + { + std::string a = exprstack.top(); + exprstack.pop(); + + elements.push_back(a); + } + + std::string r = "["; + for (int j = 0; j < (int)elements.size() - 1; j++) + { + r += elements[j]; + if (j + 1 != (int)elements.size()) + r += ", "; + } + if (!elements.empty()) + { + r += elements[elements.size() - 1]; + } + r += "]"; + + exprstack.push(r); + } + else if (t[0] == '@') + { + int input_index = std::stoi(t.substr(1)); + std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name); + exprstack.push(varid); + } + else + { + // literal + exprstack.push(t); + } + } + + std::string r = exprstack.top(); + exprstack.pop(); + + return r; +} + +static std::string make_slice_expression(const Operator *op) +{ + for (size_t j = 0; j < op->inputnames.size(); j++) + { + fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str()); + } + + std::vector dims = op->params.at("dims").ai; + + std::string r; + + int last_dim = -1; + const int ndim = (int)dims.size(); + for (int i = 0; i < ndim; i++) + { + int dim = dims[i]; + for (int j = last_dim + 1; j < dim; j++) + { + r += ":,"; + } + last_dim = dim; + + if (op->params.find("starts") != op->params.end()) + 
{ + std::vector starts = op->params.at("starts").ai; + int start = starts[i]; + + if (start != 0) + r += std::to_string(start); + } + else + { + fprintf(stderr, "find start\n"); + // find start + for (size_t j = 0; j < op->inputnames.size(); j++) + { + if (op->inputnames[j] == "start") + { + r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + + fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str()); + break; + } + } + } + + r += ':'; + + if (op->params.find("ends") != op->params.end()) + { + std::vector ends = op->params.at("ends").ai; + int end = ends[i]; + if (end != -1) + r += std::to_string(end); + } + else + { + // find end + for (size_t j = 0; j < op->inputnames.size(); j++) + { + if (op->inputnames[j] == "end") + { + r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + break; + } + } + } + + if (op->params.find("steps") != op->params.end()) + { + std::vector steps = op->params.at("steps").ai; + int step = steps[i]; + if (step != 1) + { + r += ':'; + r += std::to_string(step); + } + } + else + { + // find step + for (size_t j = 0; j < op->inputnames.size(); j++) + { + if (op->inputnames[j] == "step") + { + r += ':'; + r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + break; + } + } + } + + if (i + 1 != ndim) + r += ','; + } + + return r; +} + +int Graph::python(const std::string &pypath, const std::string &pnnxbinpath) +{ + FILE *pyfp = fopen(pypath.c_str(), "wb"); + if (!pyfp) + { + fprintf(stderr, "fopen %s failed\n", pypath.c_str()); + return -1; + } + + fprintf(pyfp, "import os\n"); + fprintf(pyfp, "import numpy as np\n"); + fprintf(pyfp, "import tempfile, zipfile\n"); + fprintf(pyfp, "import torch\n"); + fprintf(pyfp, "import torch.nn as nn\n"); + fprintf(pyfp, "import torch.nn.functional as F\n"); + + fprintf(pyfp, "\n"); + + fprintf(pyfp, "class Model(nn.Module):\n"); + fprintf(pyfp, " def __init__(self):\n"); + fprintf(pyfp, " super(Model, self).__init__()\n"); + + fprintf(pyfp, "\n"); + + 
// module + { + for (const Operator *op : ops) + { + if (op->type.substr(0, 3) != "nn.") + continue; + + fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str()); + + int param_count = op->params.size(); + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + param_count -= 2; // ignore scale and zero_point + } + + int param_index = 0; + for (const auto &it : op->params) + { + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + if (it.first == "scale" || it.first == "zero_point") + continue; + } + + fprintf(pyfp, "%s=", it.first.c_str()); + + const Parameter ¶m = it.second; + if (param.type == 0) + { + fprintf(pyfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(pyfp, "True"); + else + fprintf(pyfp, "False"); + } + if (param.type == 2) + { + fprintf(pyfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(pyfp, "%f", param.f); + } + if (param.type == 4) + { + if (param.s.substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.s.c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.s.c_str()); + } + } + if (param.type == 5) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(pyfp, "%d", param.ai[i]); + if (i + 1 != param.ai.size() || param.ai.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 6) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(pyfp, "%f", param.af[i]); + if (i + 1 != param.af.size() || param.af.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 7) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + if (param.as[i].substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.as[i].c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.as[i].c_str()); + } + if (i + 1 != param.as.size() || param.as.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + + 
param_index++; + if (param_index != param_count) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, ")\n"); + } + } + + fprintf(pyfp, "\n"); + + // load weights + { + fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str()); + + for (const Operator *op : ops) + { + if (op->type.substr(0, 3) != "nn.") + continue; + + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + for (const auto &it : op->attrs) + { + if (it.first == "weight" || it.first == "bias") + { + fprintf(pyfp, " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + else + { + // unknown attr + continue; + } + + const Attribute &attr = it.second; + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d", attr.shape[i]); + if (i + 1 != attr.shape.size()) + fprintf(pyfp, ","); + } + + fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); + } + + fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str()); + fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f); + fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i); + + continue; + } + + for (const auto &it : op->attrs) + { + if (it.first == "running_mean" || it.first == "running_var") + { + fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + else + { + fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + + const Attribute &attr = it.second; + for (size_t i = 0; i < 
attr.shape.size(); i++) + { + fprintf(pyfp, "%d", attr.shape[i]); + if (i + 1 != attr.shape.size()) + fprintf(pyfp, ","); + } + + fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + } + } + + fprintf(pyfp, " archive.close()\n"); + } + + fprintf(pyfp, "\n"); + + // utility function + { + fprintf(pyfp, " def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n"); + fprintf(pyfp, " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n"); + fprintf(pyfp, " _, tmppath = tempfile.mkstemp()\n"); + fprintf(pyfp, " tmpf = open(tmppath, 'wb')\n"); + fprintf(pyfp, " with archive.open(key) as keyfile:\n"); + fprintf(pyfp, " tmpf.write(keyfile.read())\n"); + fprintf(pyfp, " tmpf.close()\n"); + fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n"); + fprintf(pyfp, " os.remove(tmppath)\n"); + fprintf(pyfp, " return torch.from_numpy(m)\n"); + } + + fprintf(pyfp, "\n"); + + // def forward + { + fprintf(pyfp, " def forward(self"); + + for (const Operator *op : ops) + { + if (op->type != "pnnx.Input") + continue; + + fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + } + + fprintf(pyfp, "):\n"); + } + + // forward body + { + for (const Operator *op : ops) + { + if (op->type == "pnnx.Input" || op->type == "pnnx.Output") + continue; + + fprintf(pyfp, " "); + + if (op->type == "pnnx.Expression") + { + // expr + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + std::string expanded_expr = expand_expression(op); + fprintf(pyfp, " = %s\n", expanded_expr.c_str()); + } + else if (op->type == "Tensor.slice") + { + // slice expr + std::string slice_expr = make_slice_expression(op); + fprintf(pyfp, "v_%s = 
v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); + } + else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") + { + // view reshape + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector &shape = op->params.at("shape").ai; + for (size_t i = 0; i < shape.size(); i++) + { + fprintf(pyfp, "%d", shape[i]); + if (i + 1 != shape.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "Tensor.repeat") + { + // view reshape + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector &sizes = op->params.at("sizes").ai; + for (size_t i = 0; i < sizes.size(); i++) + { + fprintf(pyfp, "%d", sizes[i]); + if (i + 1 != sizes.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "torch.cat") + { + // cat + fprintf(pyfp, "v_%s = torch.cat(", sanitize_identifier(op->outputs[0]->name).c_str()); + if (op->inputs.size() == 1) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "("); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")"); + } + fprintf(pyfp, ", dim=%d", op->params.at("dim").i); + fprintf(pyfp, ")\n"); + } + else if (op->type == "prim::TupleUnpack") + { + for (size_t i = 0; i < op->outputs.size(); 
i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else if (op->type == "prim::TupleConstruct") + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + fprintf(pyfp, " = ("); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "prim::ListUnpack") + { + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else if (op->type == "prim::ListConstruct") + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + fprintf(pyfp, " = ["); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "]\n"); + } + else if (op->type == "nn.LSTM") + { + if (op->outputs.size() == 1) + { + fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str()); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + if (op->inputs.size() == 3) + { + fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type.substr(0, 3) == "nn.") + { + // 
self.xxx() + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + } + else if (op->type.find("::") != std::string::npos || op->type.find(".") != std::string::npos) + { + // direct + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + + if (op->type.substr(0, 7) == "Tensor.") + { + fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + } + else + { + fprintf(pyfp, " = %s(", op->type.c_str()); + + if (op->inputnames.size() == op->inputs.size()) + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (!op->inputnames[i].empty()) + continue; + + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (op->inputnames[i].empty()) + continue; + + fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + } + else + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + } + } + + int i = 0; + for (const auto &it : op->params) + { + if (op->type.substr(0, 7) == "Tensor." 
&& i == 0) + { + fprintf(pyfp, "%s=", it.first.c_str()); + } + else + { + fprintf(pyfp, ", %s=", it.first.c_str()); + } + + i++; + + const Parameter ¶m = it.second; + if (param.type == 0) + { + fprintf(pyfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(pyfp, "True"); + else + fprintf(pyfp, "False"); + } + if (param.type == 2) + { + fprintf(pyfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(pyfp, "%f", param.f); + } + if (param.type == 4) + { + if (param.s.substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.s.c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.s.c_str()); + } + } + if (param.type == 5) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(pyfp, "%d", param.ai[i]); + if (i + 1 != param.ai.size() || param.ai.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 6) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(pyfp, "%f", param.af[i]); + if (i + 1 != param.af.size() || param.af.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 7) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + if (param.as[i].substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.as[i].c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.as[i].c_str()); + } + if (i + 1 != param.as.size() || param.as.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + } + + fprintf(pyfp, ")\n"); + } + else + { + fprintf(stderr, "todo %s\n", op->type.c_str()); + } + } + } + + // return + { + fprintf(pyfp, " return "); + + int output_count = 0; + { + for (const Operator *op : ops) + { + if (op->type == "pnnx.Output") + output_count++; + } + } + + int output_index = 0; + for (const Operator *op : ops) + { + if (op->type != "pnnx.Output") + continue; + + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + if (output_index + 1 != output_count) + 
fprintf(pyfp, ", "); + + output_index++; + } + + fprintf(pyfp, "\n"); + } + + fprintf(pyfp, "\n"); + + // export torchscript + { + fprintf(pyfp, "def export_torchscript():\n"); + fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.eval()\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + std::vector input_names; + for (const Operator *op : ops) + { + if (op->type != "pnnx.Input") + continue; + + const Operand *r = op->outputs[0]; + std::string input_name = std::string("v_") + sanitize_identifier(r->name); + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + + input_names.push_back(input_name); + } + + fprintf(pyfp, "\n"); + + if (input_names.size() == 1) + { + fprintf(pyfp, " mod = torch.jit.trace(net, %s)\n", input_names[0].c_str()); + } + else + { + fprintf(pyfp, " mod = torch.jit.trace(net, ("); + + for (size_t i = 0; i < input_names.size(); i++) + { + fprintf(pyfp, "%s", input_names[i].c_str()); + if (i + 1 != input_names.size()) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, "))\n"); + } + + fprintf(pyfp, " mod.save(\"%s.pt\")\n", pypath.c_str()); + } + + fprintf(pyfp, "\n"); + + // test inference + { + fprintf(pyfp, "def test_inference():\n"); + fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.eval()\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + std::vector input_names; + for (const Operator *op : ops) + { + if (op->type != "pnnx.Input") + continue; + + const Operand *r = op->outputs[0]; + std::string input_name = std::string("v_") + sanitize_identifier(r->name); + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + + 
input_names.push_back(input_name); + } + + fprintf(pyfp, "\n"); + + if (input_names.size() == 1) + { + fprintf(pyfp, " return net(%s)\n", input_names[0].c_str()); + } + else + { + fprintf(pyfp, " return net("); + + for (size_t i = 0; i < input_names.size(); i++) + { + fprintf(pyfp, "%s", input_names[i].c_str()); + if (i + 1 != input_names.size()) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, ")\n"); + } + } + + fclose(pyfp); + + return 0; +} + +static bool string_is_positive_integer(const std::string &t) +{ + for (size_t i = 0; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +int Graph::parse(const std::string ¶m) +{ + std::istringstream is(param); + if (!is.good()) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + int magic = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> magic; + } + + int operator_count = 0; + int operand_count = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> operator_count >> operand_count; + } + + for (int i = 0; i < operator_count; i++) + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + std::string type; + std::string name; + int input_count = 0; + int output_count = 0; + + iss >> type >> name >> input_count >> output_count; + + Operator *op = new_operator(type, name); + + for (int j = 0; j < input_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand *r = get_operand(operand_name); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int j = 0; j < output_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand *r = new_operand(operand_name); + r->producer = op; + op->outputs.push_back(r); + } + + // key=value + while (!iss.eof()) + { + std::string param; + iss >> param; + + std::string key; + std::string value; + std::istringstream pss(param); + std::getline(pss, key, '='); + 
std::getline(pss, value); + + if (key[0] == '@') + { + // attribute + // load_attribute(op, key.substr(1), value, szr); + } + else if (key[0] == '$') + { + // operand input key + // load_input_key(op, key.substr(1), value); + } + else if (key[0] == '#') + { + // operand shape + load_shape(op, key.substr(1), value); + } + else + { + // parameter + load_parameter(op, key, value); + } + } + } + + return 0; +} + +void Operand::remove_consumer(const Operator *c) +{ + auto it = std::find(consumers.begin(), consumers.end(), c); + consumers.erase(it); +} + +Operator *Graph::new_operator(const std::string &type, const std::string &name) +{ + Operator *op = new Operator; + op->type = type; + op->name = name; + ops.push_back(op); + return op; +} + +Operator *Graph::new_operator_before(const std::string &type, const std::string &name, const Operator *cur) +{ + Operator *op = new Operator; + op->type = type; + op->name = name; + ops.insert(std::find(ops.begin(), ops.end(), cur), op); + return op; +} + +Operand *Graph::new_operand(const std::string &name) +{ + Operand *r = new Operand; + r->name = name; + operands.push_back(r); + return r; +} + +Operand *Graph::get_operand(const std::string &name) +{ + for (Operand *r : operands) + { + if (r->name == name) + return r; + } + + return 0; +} + +} // namespace pnnx diff --git a/src/importer/pnnx/ir.h b/src/importer/pnnx/ir.h new file mode 100644 index 000000000..76c6084a5 --- /dev/null +++ b/src/importer/pnnx/ir.h @@ -0,0 +1,238 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_IR_H +#define PNNX_IR_H + +#include +#include +#include +#include + +#include "nncase/ir/ir_types.h" + +namespace pnnx +{ + +class Parameter +{ +public: + Parameter() + : type(0) + { + } + Parameter(bool _b) + : type(1), b(_b) + { + } + Parameter(int _i) + : type(2), i(_i) + { + } + Parameter(long _l) + : type(2), i(_l) + { + } + Parameter(long long _l) + : type(2), i(_l) + { + } + Parameter(float _f) + : type(3), f(_f) + { + } + Parameter(double _d) + : type(3), f(_d) + { + } + Parameter(const char *_s) + : type(4), s(_s) + { + } + Parameter(const std::string &_s) + : type(4), s(_s) + { + } + Parameter(const std::initializer_list &_ai) + : type(5), ai(_ai) + { + } + Parameter(const std::initializer_list &_ai) + : type(5) + { + for (const auto &x : _ai) + ai.push_back((int)x); + } + Parameter(const std::vector &_ai) + : type(5), ai(_ai) + { + } + Parameter(const std::initializer_list &_af) + : type(6), af(_af) + { + } + Parameter(const std::initializer_list &_af) + : type(6) + { + for (const auto &x : _af) + af.push_back((float)x); + } + Parameter(const std::vector &_af) + : type(6), af(_af) + { + } + Parameter(const std::initializer_list &_as) + : type(7) + { + for (const auto &x : _as) + as.push_back(std::string(x)); + } + Parameter(const std::initializer_list &_as) + : type(7), as(_as) + { + } + Parameter(const std::vector &_as) + : type(7), as(_as) + { + } + + static Parameter parse_from_string(const std::string &value); + + // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others + int type; + + // value + bool b; + int i; + float f; + 
std::string s; + std::vector ai; + std::vector af; + std::vector as; +}; + +class Attribute +{ +public: + Attribute() + : type(0) + { + } + + Attribute(const std::initializer_list &shape, const std::vector &t); + + nncase::ir::shape_t get_shape() const + { + nncase::ir::shape_t s; + for (auto v : shape) + s.push_back(v); + return s; + } + + std::span get_data() const + { + return std::span { (float *)data.data(), (size_t)data.size() / sizeof(float) }; + } + + // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 + int type; + std::vector shape; + + std::vector data; +}; + +class Operator; +class Operand +{ +public: + void remove_consumer(const Operator *c); + + nncase::ir::shape_t get_shape() const + { + nncase::ir::shape_t s; + for (auto v : shape) + s.push_back(v); + return s; + } + + std::string name; + + Operator *producer; + std::vector consumers; + + // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 + int type; + std::vector shape; + + std::map params; + +private: + friend class Graph; + Operand() + { + } +}; + +class Operator +{ +public: + std::string type; + std::string name; + + std::vector inputs; + std::vector outputs; + + std::vector inputnames; + std::map params; + std::map attrs; + +private: + friend class Graph; + Operator() + { + } +}; + +class Graph +{ +public: + Graph(); + ~Graph(); + + int load(const std::string ¶mpath, const std::string &binpath); + int save(const std::string ¶mpath, const std::string &binpath); + + int python(const std::string &pypath, const std::string &binpath); + + int parse(const std::string ¶m); + + Operator *new_operator(const std::string &type, const std::string &name); + + Operator *new_operator_before(const std::string &type, const std::string &name, const Operator *cur); + + Operand *new_operand(const std::string &name); + + Operand *get_operand(const std::string &name); + + std::vector ops; + std::vector operands; + +private: + Graph(const Graph &rhs); + Graph &operator=(const Graph &rhs); +}; + +} // 
namespace pnnx + +#endif // PNNX_IR_H diff --git a/src/importer/pnnx/opcode.def b/src/importer/pnnx/opcode.def new file mode 100644 index 000000000..1975efeba --- /dev/null +++ b/src/importer/pnnx/opcode.def @@ -0,0 +1,9 @@ +DEFINE_OPCODE(pnnx.Input, pnnx_Input) +DEFINE_OPCODE(pnnx.Output, pnnx_Output) + +DEFINE_OPCODE(F.relu, F_relu) +DEFINE_OPCODE(F.relu6, F_relu6) + +DEFINE_OPCODE(nn.Conv2d, nn_Conv2d) +DEFINE_OPCODE(nn.ReLU, nn_ReLU) +DEFINE_OPCODE(nn.ReLU6, nn_ReLU6) diff --git a/src/importer/pnnx/ops/conv2d.cpp b/src/importer/pnnx/ops/conv2d.cpp new file mode 100644 index 000000000..37fc02431 --- /dev/null +++ b/src/importer/pnnx/ops/conv2d.cpp @@ -0,0 +1,112 @@ +// Tencent is pleased to support the open source community by making pnnx available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "../pnnx_importer.h" +#include "nncase/importer/util.h" +#include "nncase/ir/ir_types.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::importer; +using namespace nncase::ir; +using namespace pnnx; + +void nncase::importer::pnnx_importer::convert_op_nn_Conv2d(const Operator &op) +{ + const auto &op_name = op.name; + + auto in_shape = op.inputs[0]->get_shape(); + auto weight_shape = op.attrs.at("op_0.weight").get_shape(); + + const int dilation_w = op.params.at("dilation").ai[1]; + const int dilation_h = op.params.at("dilation").ai[0]; + const int stride_w = op.params.at("stride").ai[1]; + const int stride_h = op.params.at("stride").ai[0]; + const int group = op.params.at("group").i; + std::string padding_mode = op.params.at("padding_mode").s; + + padding padding_w; + padding padding_h; + if (op.params.at("padding").type == 4) + { + if (op.params.at("padding").s == "same") + { + padding_w = get_windowed_padding(in_shape[3], weight_shape[3], stride_w, dilation_w, true); + padding_h = get_windowed_padding(in_shape[2], weight_shape[2], stride_h, dilation_h, true); + } + else // if (op.params.at("padding").s == "valid") + { + padding_w = { 0, 0 }; + padding_h = { 0, 0 }; + } + } + else + { + padding_w = { op.params.at("padding").ai[1], op.params.at("padding").ai[1] }; + padding_h = { op.params.at("padding").ai[0], op.params.at("padding").ai[0] }; + } + + ir::pad *pad_op = 0; + if (padding_mode == "reflect" || padding_mode == "replicate") + { + xt::svector paddings = { { 0, 0 }, { 0, 0 }, padding_h, padding_w }; + pad_mode_t pad_mode = padding_mode == "reflect" ? 
pad_reflect : pad_edge; + + pad_op = graph_.emplace(dt_float32, in_shape, paddings, pad_mode, 0.f); + pad_op->name(op_name + ".pad(Convolution)"); + + padding_w = { 0, 0 }; + padding_h = { 0, 0 }; + } + + ir::conv2d *conv_op = graph_.emplace(in_shape, weight_shape, group, padding_h, padding_w, stride_h, stride_w, dilation_h, dilation_w, value_range::full()); + conv_op->name(op_name + ".conv2d(Conv2d)"); + + if (pad_op) + { + conv_op->input().connect(pad_op->output()); + } + + auto weight_data = op.attrs.at("op_0.weight").get_data(); + + auto weight_node = graph_.emplace(dt_float32, weight_shape, weight_data); + conv_op->weights().connect(weight_node->output()); + + if (op.params.at("bias").b) + { + auto bias_shape = op.attrs.at("op_0.bias").get_shape(); + auto bias_data = op.attrs.at("op_0.bias").get_data(); + + auto bias_node = graph_.emplace(dt_float32, bias_shape, bias_data); + conv_op->bias().connect(bias_node->output()); + } + + if (pad_op) + { + input_tensors_.emplace(&pad_op->input(), op.inputs[0]->name); + } + else + { + input_tensors_.emplace(&conv_op->input(), op.inputs[0]->name); + } + + output_tensors_.emplace(op.outputs[0]->name, &conv_op->output()); +} diff --git a/src/importer/pnnx/ops/input.cpp b/src/importer/pnnx/ops/input.cpp new file mode 100644 index 000000000..763f50038 --- /dev/null +++ b/src/importer/pnnx/ops/input.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making pnnx available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "../pnnx_importer.h" +#include "nncase/importer/util.h" +#include "nncase/ir/ir_types.h" +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::importer; +using namespace nncase::ir; +using namespace pnnx; + +void nncase::importer::pnnx_importer::convert_op_pnnx_Input(const Operator &op) +{ + const auto &op_name = op.name; + + for (auto r : op.outputs) + { + auto in_shape = r->get_shape(); + + auto node = graph_.emplace(dt_float32, in_shape); + node->name(op_name + "." + r->name + "(Input)"); + + output_tensors_.emplace(r->name, &node->output()); + } +} diff --git a/src/importer/pnnx/ops/output.cpp b/src/importer/pnnx/ops/output.cpp new file mode 100644 index 000000000..7653b9afa --- /dev/null +++ b/src/importer/pnnx/ops/output.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making pnnx available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "../pnnx_importer.h" +#include "nncase/importer/util.h" +#include "nncase/ir/ir_types.h" +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::importer; +using namespace nncase::ir; +using namespace pnnx; + +void nncase::importer::pnnx_importer::convert_op_pnnx_Output(const Operator &op) +{ + const auto &op_name = op.name; + + for (auto r : op.inputs) + { + auto in_shape = r->get_shape(); + + auto node = graph_.emplace(dt_float32, in_shape); + node->name(op_name + "." + r->name + "(Output)"); + + input_tensors_.emplace(&node->input(), r->name); + } +} diff --git a/src/importer/pnnx/ops/relu.cpp b/src/importer/pnnx/ops/relu.cpp new file mode 100644 index 000000000..81a1a65ca --- /dev/null +++ b/src/importer/pnnx/ops/relu.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making pnnx available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "../pnnx_importer.h" +#include "nncase/importer/util.h" +#include "nncase/ir/ir_types.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::importer; +using namespace nncase::ir; +using namespace pnnx; + +void nncase::importer::pnnx_importer::convert_op_F_relu(const Operator &op) +{ + const auto &op_name = op.name; + + auto in_shape = op.inputs[0]->get_shape(); + + auto zero = graph_.emplace(0.f); + zero->name(op_name + ".zero(ReLU)"); + + auto max = graph_.emplace(binary_max, in_shape, zero->output().shape(), value_range::full()); + max->name(op_name + ".max(ReLU)"); + + max->input_b().connect(zero->output()); + + input_tensors_.emplace(&max->input_a(), op.inputs[0]->name); + output_tensors_.emplace(op.outputs[0]->name, &max->output()); +} + +void nncase::importer::pnnx_importer::convert_op_nn_ReLU(const Operator &op) +{ + convert_op_F_relu(op); +} diff --git a/src/importer/pnnx/ops/relu6.cpp b/src/importer/pnnx/ops/relu6.cpp new file mode 100644 index 000000000..b0ee0f9c3 --- /dev/null +++ b/src/importer/pnnx/ops/relu6.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making pnnx available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "../pnnx_importer.h" +#include "nncase/importer/util.h" +#include "nncase/ir/ir_types.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::importer; +using namespace nncase::ir; +using namespace pnnx; + +void nncase::importer::pnnx_importer::convert_op_F_relu6(const Operator &op) +{ + const auto &op_name = op.name; + + auto in_shape = op.inputs[0]->get_shape(); + + auto zero = graph_.emplace(0.f); + zero->name(op_name + ".zero(ReLU6)"); + auto six = graph_.emplace(6.f); + six->name(op_name + ".six(ReLU6)"); + + auto clamp_op = graph_.emplace(in_shape, zero->output().shape(), six->output().shape()); + clamp_op->name(op_name + ".clamp(ReLU6)"); + + clamp_op->input_low().connect(zero->output()); + clamp_op->input_high().connect(six->output()); + + input_tensors_.emplace(&clamp_op->input(), op.inputs[0]->name); + output_tensors_.emplace(op.outputs[0]->name, &clamp_op->output()); +} + +void nncase::importer::pnnx_importer::convert_op_nn_ReLU6(const Operator &op) +{ + convert_op_F_relu6(op); +} diff --git a/src/importer/pnnx/pnnx_importer.cpp b/src/importer/pnnx/pnnx_importer.cpp new file mode 100644 index 000000000..ad762bb89 --- /dev/null +++ b/src/importer/pnnx/pnnx_importer.cpp @@ -0,0 +1,175 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pnnx_importer.h" +#include "ir.h" +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace nncase; +using namespace nncase::importer; +using namespace nncase::ir; + +namespace pnnx +{ + +void chain_multi_output(Graph &graph) +{ + for (;;) + { + bool need_eliminate = false; + + for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + { + Operator *op = graph.ops[i]; + + if (op->type != "pnnx.Output") + continue; + + // prim::TupleConstruct pnnx_791 2 1 a b out + // pnnx.Expression pnnx_expr_0 3 1 a b c out expr=[@0,@1,@2] + // pnnx.Output pnnx_output_0 1 0 out + bool match_tuple_expr_output = false; + for (int j = 0; j < (int)op->inputs.size(); j++) + { + Operand *r = op->inputs[j]; + + if (r->consumers.size() != 1) + continue; + + Operator *op0 = r->producer; + + if (op0->type == "prim::TupleConstruct") + { + match_tuple_expr_output = true; + } + else if (op0->type == "pnnx.Expression") + { + const int op_expr_input_count = (int)op0->inputs.size(); + const std::string &expr = op0->params.at("expr").s; + + std::string pattern_expr = "["; + for (int k = 0; k < op_expr_input_count; k++) + { + pattern_expr += std::string("@") + std::to_string(k); + + if (k != op_expr_input_count - 1) + pattern_expr += ","; + } + pattern_expr += "]"; + + if (expr == pattern_expr) + { + match_tuple_expr_output = true; + } + } + + if (!match_tuple_expr_output) + continue; + + // chain op0 as output and delete op0 + std::vector new_inputs; + for (int k = 0; k < j; k++) + { + new_inputs.push_back(op->inputs[k]); + } + + for (Operand *r : op0->inputs) + { + r->remove_consumer(op0); + r->consumers.push_back(op); + new_inputs.push_back(r); + } + + for (int k = j + 1; k < (int)op->inputs.size(); k++) + { + new_inputs.push_back(op->inputs[k]); + } + + op->inputs = new_inputs; + + // capture op0's orphaned output BEFORE clearing op0->outputs; + // indexing outputs[0] after clear() reads an emptied vector (UB) + Operand *op0_out = op0->outputs[0]; + op0_out->producer = 0; + op0_out->consumers.clear(); + + op0->inputs.clear(); + op0->outputs.clear(); + + 
graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op0_out)); + delete op0_out; + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op0)); + delete op0; + + break; + } + + if (match_tuple_expr_output) + need_eliminate = true; + + break; + } + + if (!need_eliminate) + break; + } +} + +} // namespace pnnx + +pnnx_importer::pnnx_importer(std::string parampath, std::string binpath, ir::graph &graph) + : graph_(graph) +{ + pnnx_graph_.load(parampath, binpath); + + pnnx::chain_multi_output(pnnx_graph_); +} + +void pnnx_importer::import(const struct import_options & /*options*/, std::string & /*real_inlayout*/, std::string & /*real_outlayout*/) +{ + for (const pnnx::Operator *op : pnnx_graph_.ops) + { + convert_op(*op); + } + + // connect tensors + for (auto &&in : input_tensors_) + { + auto out_it = output_tensors_.find(in.second); + if (out_it != output_tensors_.end()) + { + in.first->connect(*out_it->second); + } + else + { + assert(!"Cannot find associated output node"); + } + } +} + +void pnnx_importer::convert_op(const pnnx::Operator &op) +{ +#define DEFINE_OPCODE(opcode, opcode2) \ + if (op.type == #opcode##sv) \ + return convert_op_##opcode2(op); +#include "opcode.def" +#undef DEFINE_OPCODE + + throw std::runtime_error("Not supported pnnx opcode: " + op.type); +} diff --git a/src/importer/pnnx/pnnx_importer.h b/src/importer/pnnx/pnnx_importer.h new file mode 100644 index 000000000..c312f548e --- /dev/null +++ b/src/importer/pnnx/pnnx_importer.h @@ -0,0 +1,60 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "ir.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nncase +{ +namespace importer +{ + class pnnx_importer + { + public: + pnnx_importer(std::string parampath, std::string binpath, ir::graph &graph); + + void import(const struct import_options &options, std::string &real_inlayout, std::string &real_outlayout); + + private: + void convert_op(const pnnx::Operator &op); + +#define DEFINE_OPCODE(opcode, opcode2) void convert_op_##opcode2(const pnnx::Operator &op); +#include "opcode.def" +#undef DEFINE_OPCODE + + private: + ir::graph &graph_; + pnnx::Graph pnnx_graph_; + std::unordered_map input_tensors_; + std::unordered_map output_tensors_; + }; +} +} + +#define DEFINE_PNNX_LOWER(opcode) \ + void nncase::importer::pnnx_importer::convert_op_##opcode(const pnnx::Operator &op) diff --git a/src/importer/pnnx/storezip.cpp b/src/importer/pnnx/storezip.cpp new file mode 100644 index 000000000..1a6c554e2 --- /dev/null +++ b/src/importer/pnnx/storezip.cpp @@ -0,0 +1,406 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "storezip.h"
+
+// NOTE(review): the five #include targets below were destroyed by text
+// extraction (angle-bracket spans stripped); restore them from the original
+// patch before applying.
+#include
+#include
+#include
+#include
+#include
+
+namespace pnnx
+{
+
+// https://stackoverflow.com/questions/1537964/visual-c-equivalent-of-gccs-attribute-packed
+#ifdef _MSC_VER
+#define PACK(__Declaration__) __pragma(pack(push, 1)) __Declaration__ __pragma(pack(pop))
+#else
+#define PACK(__Declaration__) __Declaration__ __attribute__((__packed__))
+#endif
+
+// On-disk ZIP local file header record, byte-packed; the leading 4-byte
+// signature (0x04034b50) is read/written separately.
+PACK(struct local_file_header {
+    uint16_t version;
+    uint16_t flag;
+    uint16_t compression;
+    uint16_t last_modify_time;
+    uint16_t last_modify_date;
+    uint32_t crc32;
+    uint32_t compressed_size;
+    uint32_t uncompressed_size;
+    uint16_t file_name_length;
+    uint16_t extra_field_length;
+});
+
+// On-disk ZIP central directory file header record (signature 0x02014b50
+// handled separately).
+PACK(struct central_directory_file_header {
+    uint16_t version_made;
+    uint16_t version;
+    uint16_t flag;
+    uint16_t compression;
+    uint16_t last_modify_time;
+    uint16_t last_modify_date;
+    uint32_t crc32;
+    uint32_t compressed_size;
+    uint32_t uncompressed_size;
+    uint16_t file_name_length;
+    uint16_t extra_field_length;
+    uint16_t file_comment_length;
+    uint16_t start_disk;
+    uint16_t internal_file_attrs;
+    uint32_t external_file_attrs;
+    uint32_t lfh_offset;
+});
+
+// On-disk ZIP end-of-central-directory record (signature 0x06054b50 handled
+// separately).
+PACK(struct end_of_central_directory_record {
+    uint16_t disk_number;
+    uint16_t start_disk;
+    uint16_t cd_records;
+    uint16_t total_cd_records;
+    uint32_t cd_size;
+    uint32_t cd_offset;
+    uint16_t comment_length;
+});
+
+static uint32_t CRC32_TABLE[256];
+
+// Fills CRC32_TABLE with the reflected CRC-32 lookup table (polynomial
+// 0xedb88320, the one ZIP uses).
+static void CRC32_TABLE_INIT()
+{
+    for (int i = 0; i < 256; i++)
+    {
+        uint32_t c = i;
+        for (int j = 0; j < 8; j++)
+        {
+            if (c & 1)
+                c = (c >> 1) ^
0xedb88320;
+            else
+                c >>= 1;
+        }
+        CRC32_TABLE[i] = c;
+    }
+}
+
+// One CRC-32 step: folds byte ch into the running value x.
+static uint32_t CRC32(uint32_t x, unsigned char ch)
+{
+    return (x >> 8) ^ CRC32_TABLE[(x ^ ch) & 0xff];
+}
+
+// CRC-32 of a whole buffer (standard 0xffffffff init and final xor).
+static uint32_t CRC32_buffer(const unsigned char *data, int len)
+{
+    uint32_t x = 0xffffffff;
+
+    for (int i = 0; i < len; i++)
+        x = CRC32(x, data[i]);
+
+    return x ^ 0xffffffff;
+}
+
+StoreZipReader::StoreZipReader()
+{
+    fp = 0;
+}
+
+StoreZipReader::~StoreZipReader()
+{
+    close();
+}
+
+// Scans the archive record by record (dispatching on each 4-byte signature)
+// and indexes every member's data offset/size into filemetas. Only
+// uncompressed ("store") members without data descriptors are accepted.
+// Returns 0 on success, -1 on error.
+int StoreZipReader::open(const std::string &path)
+{
+    close();
+
+    fp = fopen(path.c_str(), "rb");
+    if (!fp)
+    {
+        fprintf(stderr, "open failed\n");
+        return -1;
+    }
+
+    while (!feof(fp))
+    {
+        // peek signature
+        uint32_t signature;
+        size_t nread = fread((char *)&signature, sizeof(signature), 1, fp);
+        if (nread != 1)
+            break;
+
+        if (signature == 0x04034b50)
+        {
+            local_file_header lfh;
+            nread = fread((char *)&lfh, sizeof(lfh), 1, fp);
+            if (nread != 1)
+                break;
+
+            if (lfh.flag & 0x08)
+            {
+                fprintf(stderr, "zip file contains data descriptor, this is not supported yet\n");
+                return -1;
+            }
+
+            if (lfh.compression != 0 || lfh.compressed_size != lfh.uncompressed_size)
+            {
+                fprintf(stderr, "not stored zip file %d %d\n", lfh.compressed_size, lfh.uncompressed_size);
+                return -1;
+            }
+
+            // file name
+            std::string name;
+            name.resize(lfh.file_name_length);
+            nread = fread((char *)name.data(), name.size(), 1, fp);
+            if (nread != 1)
+                break;
+
+            // skip extra field
+            fseek(fp, lfh.extra_field_length, SEEK_CUR);
+
+            StoreZipMeta fm;
+            fm.offset = ftell(fp);
+            fm.size = lfh.compressed_size;
+
+            filemetas[name] = fm;
+
+            // fprintf(stderr, "%s = %d %d\n", name.c_str(), fm.offset, fm.size);
+
+            fseek(fp, lfh.compressed_size, SEEK_CUR);
+        }
+        else if (signature == 0x02014b50)
+        {
+            central_directory_file_header cdfh;
+            nread = fread((char *)&cdfh, sizeof(cdfh), 1, fp);
+            if (nread != 1)
+                break;
+
+            // skip file name
+            fseek(fp, cdfh.file_name_length, SEEK_CUR);
+
+            // skip extra field
+            fseek(fp,
cdfh.extra_field_length, SEEK_CUR);
+
+            // skip file comment
+            fseek(fp, cdfh.file_comment_length, SEEK_CUR);
+        }
+        else if (signature == 0x06054b50)
+        {
+            end_of_central_directory_record eocdr;
+            nread = fread((char *)&eocdr, sizeof(eocdr), 1, fp);
+            if (nread != 1)
+                break;
+
+            // skip comment
+            fseek(fp, eocdr.comment_length, SEEK_CUR);
+        }
+        else
+        {
+            fprintf(stderr, "unsupported signature %x\n", signature);
+            return -1;
+        }
+    }
+
+    return 0;
+}
+
+// Size in bytes of a member previously indexed by open(); 0 if absent.
+size_t StoreZipReader::get_file_size(const std::string &name)
+{
+    if (filemetas.find(name) == filemetas.end())
+    {
+        fprintf(stderr, "no such file %s\n", name.c_str());
+        return 0;
+    }
+
+    return filemetas[name].size;
+}
+
+// Copies one member's stored bytes into data; the caller must provide a
+// buffer of at least get_file_size(name) bytes. Returns 0 on success, -1 on
+// error.
+int StoreZipReader::read_file(const std::string &name, char *data)
+{
+    if (filemetas.find(name) == filemetas.end())
+    {
+        fprintf(stderr, "no such file %s\n", name.c_str());
+        return -1;
+    }
+
+    size_t offset = filemetas[name].offset;
+    size_t size = filemetas[name].size;
+
+    fseek(fp, offset, SEEK_SET);
+    size_t nread = fread(data, size, 1, fp);
+    if (nread != 1)
+        return -1;
+
+    return 0;
+}
+
+int StoreZipReader::close()
+{
+    if (!fp)
+        return 0;
+
+    fclose(fp);
+    fp = 0;
+
+    return 0;
+}
+
+StoreZipWriter::StoreZipWriter()
+{
+    fp = 0;
+
+    // CRC table is needed by write_file(); built once per writer.
+    CRC32_TABLE_INIT();
+}
+
+StoreZipWriter::~StoreZipWriter()
+{
+    close();
+}
+
+int StoreZipWriter::open(const std::string &path)
+{
+    close();
+
+    fp = fopen(path.c_str(), "wb");
+    if (!fp)
+    {
+        fprintf(stderr, "open failed\n");
+        return -1;
+    }
+
+    return 0;
+}
+
+// Appends one member as a stored (uncompressed) entry: local file header +
+// name + raw data. Central-directory bookkeeping is kept in filemetas until
+// close().
+int StoreZipWriter::write_file(const std::string &name, const char *data, size_t size)
+{
+    int offset = ftell(fp);
+
+    uint32_t signature = 0x04034b50;
+    fwrite((char *)&signature, sizeof(signature), 1, fp);
+
+    uint32_t crc32 = CRC32_buffer((const unsigned char *)data, size);
+
+    local_file_header lfh;
+    lfh.version = 0;
+    lfh.flag = 0;
+    lfh.compression = 0;
+    lfh.last_modify_time = 0;
+    lfh.last_modify_date = 0;
+    lfh.crc32 = crc32;
+    lfh.compressed_size = size;
+    lfh.uncompressed_size =
size;
+    lfh.file_name_length = name.size();
+    lfh.extra_field_length = 0;
+
+    fwrite((char *)&lfh, sizeof(lfh), 1, fp);
+
+    fwrite((char *)name.c_str(), name.size(), 1, fp);
+
+    fwrite(data, size, 1, fp);
+
+    StoreZipMeta szm;
+    szm.name = name;
+    szm.lfh_offset = offset;
+    szm.crc32 = crc32;
+    szm.size = size;
+
+    filemetas.push_back(szm);
+
+    return 0;
+}
+
+// Emits the central directory (one record per written member) followed by
+// the end-of-central-directory record, then closes the file. Returns 0.
+int StoreZipWriter::close()
+{
+    if (!fp)
+        return 0;
+
+    int offset = ftell(fp);
+
+    for (const StoreZipMeta &szm : filemetas)
+    {
+        uint32_t signature = 0x02014b50;
+        fwrite((char *)&signature, sizeof(signature), 1, fp);
+
+        central_directory_file_header cdfh;
+        cdfh.version_made = 0;
+        cdfh.version = 0;
+        cdfh.flag = 0;
+        cdfh.compression = 0;
+        cdfh.last_modify_time = 0;
+        cdfh.last_modify_date = 0;
+        cdfh.crc32 = szm.crc32;
+        cdfh.compressed_size = szm.size;
+        cdfh.uncompressed_size = szm.size;
+        cdfh.file_name_length = szm.name.size();
+        cdfh.extra_field_length = 0;
+        cdfh.file_comment_length = 0;
+        cdfh.start_disk = 0;
+        cdfh.internal_file_attrs = 0;
+        cdfh.external_file_attrs = 0;
+        cdfh.lfh_offset = szm.lfh_offset;
+
+        fwrite((char *)&cdfh, sizeof(cdfh), 1, fp);
+
+        fwrite((char *)szm.name.c_str(), szm.name.size(), 1, fp);
+    }
+
+    int offset2 = ftell(fp);
+
+    {
+        uint32_t signature = 0x06054b50;
+        fwrite((char *)&signature, sizeof(signature), 1, fp);
+
+        end_of_central_directory_record eocdr;
+        eocdr.disk_number = 0;
+        eocdr.start_disk = 0;
+        eocdr.cd_records = filemetas.size();
+        eocdr.total_cd_records = filemetas.size();
+        eocdr.cd_size = offset2 - offset;
+        eocdr.cd_offset = offset;
+        eocdr.comment_length = 0;
+
+        fwrite((char *)&eocdr, sizeof(eocdr), 1, fp);
+    }
+
+    fclose(fp);
+    fp = 0;
+
+    return 0;
+}
+
+} // namespace pnnx
+
+// Dead example/driver code, intentionally compiled out.
+// NOTE(review): extraction also stripped the element types of the two
+// std::vector declarations below -- restore from the original patch.
+#if 0
+int main()
+{
+    StoreZipReader sz;
+
+    sz.open("test.zip");
+
+    std::vector data1;
+    sz.read_file("pnnx2.py", data1);
+
+    std::vector data2;
+    sz.read_file("pnnx2.param", data2);
+
+    sz.close();
+
+
+    StoreZipWriter szw;
+
+    szw.open("szw.zip");
+
+
szw.write_file("a.py", data1);
+    // (still inside the #if 0 dead-code example started above)
+    szw.write_file("zzzz.param", data2);
+
+    szw.close();
+
+
+    return 0;
+}
+#endif
diff --git a/src/importer/pnnx/storezip.h b/src/importer/pnnx/storezip.h
new file mode 100644
index 000000000..644db2421
--- /dev/null
+++ b/src/importer/pnnx/storezip.h
@@ -0,0 +1,79 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef PNNX_STOREZIP_H
+#define PNNX_STOREZIP_H
+
+// NOTE(review): the three #include targets below were destroyed by text
+// extraction; the declarations in this header use std::string, std::map and
+// std::vector plus C FILE -- restore the exact targets from the original
+// patch.
+#include
+#include
+#include
+
+namespace pnnx
+{
+
+// Reads "store"-only (uncompressed) ZIP archives: open() indexes the
+// members, read_file() copies one member's raw bytes out.
+class StoreZipReader
+{
+public:
+    StoreZipReader();
+    ~StoreZipReader();
+
+    int open(const std::string &path);
+
+    size_t get_file_size(const std::string &name);
+
+    int read_file(const std::string &name, char *data);
+
+    int close();
+
+private:
+    FILE *fp;
+
+    struct StoreZipMeta
+    {
+        size_t offset;
+        size_t size;
+    };
+
+    // NOTE(review): template arguments stripped by extraction; storezip.cpp
+    // uses this as a member-name -> StoreZipMeta index.
+    std::map filemetas;
+};
+
+// Writes "store"-only (uncompressed) ZIP archives: write_file() appends
+// members, close() emits the central directory.
+class StoreZipWriter
+{
+public:
+    StoreZipWriter();
+    ~StoreZipWriter();
+
+    int open(const std::string &path);
+
+    int write_file(const std::string &name, const char *data, size_t size);
+
+    int close();
+
+private:
+    FILE *fp;
+
+    struct StoreZipMeta
+    {
+        std::string name;
+        size_t lfh_offset;
+        uint32_t crc32;
+        uint32_t size;
+    };
+
+    // NOTE(review): template arguments stripped by extraction; storezip.cpp
+    // appends one StoreZipMeta per written member.
+    std::vector filemetas;
+};
+
+} // namespace pnnx
+
+#endif // PNNX_STOREZIP_H
diff --git a/src/nncase/compiler.cpp b/src/nncase/compiler.cpp
index bb9816924..72310c4d5 100644
--- a/src/nncase/compiler.cpp
+++ b/src/nncase/compiler.cpp
@@ -169,6 +169,13 @@ class compiler_impl : public compiler
         END_IMPORT()
     }
 
+    // Imports a pnnx model pair (.param graph path + .bin weight path) by
+    // delegating to importer::import_pnnx.
+    void import_pnnx(std::string parampath, std::string binpath, const import_options &options) override
+    {
+        BEGIN_IMPORT()
+        importer::import_pnnx(graph_, parampath, binpath, imp_options, real_inlayout_, real_outlayout_);
+        END_IMPORT()
+    }
+
     void use_ptq(ptq_dataset_options options) override
     {
         ptq_options_ = std::move(options);