forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_op_library.cc
157 lines (119 loc) · 5 KB
/
custom_op_library.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include "custom_op_library.h"
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
#include <cmath>
#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>
// Name of the custom-op domain this library registers its ops under
// (passed to CreateCustomOpDomain in RegisterCustomOps). The pointer itself is
// const so the global cannot be reassigned.
static const char* const c_OpDomain = "test.customop";
struct OrtCustomOpDomainDeleter {
explicit OrtCustomOpDomainDeleter(const OrtApi* ort_api) {
ort_api_ = ort_api;
}
void operator()(OrtCustomOpDomain* domain) const {
ort_api_->ReleaseCustomOpDomain(domain);
}
const OrtApi* ort_api_;
};
// Owning handle for a custom-op domain; released via OrtCustomOpDomainDeleter.
using OrtCustomOpDomainUniquePtr = std::unique_ptr<OrtCustomOpDomain, OrtCustomOpDomainDeleter>;

// Domains handed over by RegisterCustomOps are kept here so they stay alive
// after registration. They are released when this static container is destroyed.
static std::vector<OrtCustomOpDomainUniquePtr> ort_custom_op_domain_container;
// Serializes access to ort_custom_op_domain_container.
static std::mutex ort_custom_op_domain_mutex;

// Takes ownership of `domain` and appends it to the static container.
// `ort_api` is the API table later used to release the domain.
static void AddOrtCustomOpDomainToContainer(OrtCustomOpDomain* domain, const OrtApi* ort_api) {
  OrtCustomOpDomainUniquePtr owner(domain, OrtCustomOpDomainDeleter(ort_api));
  std::lock_guard<std::mutex> guard(ort_custom_op_domain_mutex);
  ort_custom_op_domain_container.push_back(std::move(owner));
}
// The dimension sizes of a tensor OrtValue, exposed as a std::vector<int64_t>.
// Queried through the Ort::CustomOpApi wrapper passed in.
struct OrtTensorDimensions : std::vector<int64_t> {
  OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
    OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
    // Release `info` on every exit path. GetTensorShape may throw (e.g. bad_alloc
    // while building the vector, or an exception from the wrapper's error
    // checking). A manual release after the call would then be skipped and
    // `info` would leak.
    struct InfoReleaser {
      Ort::CustomOpApi& api;
      OrtTensorTypeAndShapeInfo* info;
      ~InfoReleaser() { api.ReleaseTensorTypeAndShapeInfo(info); }
    } releaser{ort, info};
    std::vector<int64_t>::operator=(ort.GetTensorShape(info));
  }
};
struct KernelOne {
KernelOne(OrtApi api)
: api_(api),
ort_(api_) {
}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
const float* X = ort_.GetTensorData<float>(input_X);
const float* Y = ort_.GetTensorData<float>(input_Y);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out = ort_.GetTensorMutableData<float>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
}
private:
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
Ort::CustomOpApi ort_;
};
// Kernel for CustomOpTwo: rounds each element of the float tensor X to the
// nearest integer (halfway cases away from zero, per std::round) and stores the
// result as int32. The output is allocated with the shape of X.
struct KernelTwo {
  KernelTwo(OrtApi api)
      : api_(api),
        ort_(api_) {
  }

  void Compute(OrtKernelContext* context) {
    // Setup inputs
    const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
    const float* X = ort_.GetTensorData<float>(input_X);

    // Setup output: same shape as X, int32 elements.
    OrtTensorDimensions dimensions(ort_, input_X);
    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
    int32_t* out = ort_.GetTensorMutableData<int32_t>(output);

    OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
    int64_t size = ort_.GetTensorShapeElementCount(output_info);
    ort_.ReleaseTensorTypeAndShapeInfo(output_info);

    // Do computation. Use std::round rather than the unqualified C ::round:
    // <cmath> is only required to declare it in namespace std. The float
    // overload gives the same result without promoting to double.
    // NOTE(review): NaN or values outside the int32 range make the conversion
    // undefined behavior; inputs are assumed to be in range — confirm.
    for (int64_t i = 0; i < size; i++) {
      out[i] = static_cast<int32_t>(std::round(X[i]));
    }
  }

 private:
  // Copy of the API struct. ort_ holds a reference to it, so api_ must be
  // declared (and therefore initialized) before ort_.
  OrtApi api_;
  Ort::CustomOpApi ort_;
};
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
return new KernelOne(api);
};
const char* GetName() const { return "CustomOpOne"; };
size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
} c_CustomOpOne;
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
return new KernelTwo(api);
};
const char* GetName() const { return "CustomOpTwo"; };
size_t GetInputTypeCount() const { return 1; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
} c_CustomOpTwo;
// Library entry point. Creates the custom-op domain named by c_OpDomain and
// adds c_CustomOpOne and c_CustomOpTwo to it. The domain is then attached to
// `options`. Returns nullptr on success, otherwise the first non-null
// OrtStatus produced along the way.
// NOTE(review): `api->GetApi(ORT_API_VERSION)` is used without a null check;
// confirm the host runtime always supports this ORT_API_VERSION.
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);

  OrtCustomOpDomain* domain = nullptr;
  OrtStatus* status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain);
  if (status != nullptr) {
    return status;
  }

  // Give the static container ownership right away, so this function never
  // has to release the domain itself on the error paths below.
  AddOrtCustomOpDomainToContainer(domain, ortApi);

  OrtCustomOp* const ops_to_add[] = {&c_CustomOpOne, &c_CustomOpTwo};
  for (OrtCustomOp* op : ops_to_add) {
    status = ortApi->CustomOpDomain_Add(domain, op);
    if (status != nullptr) {
      return status;
    }
  }

  return ortApi->AddCustomOpDomain(options, domain);
}