Skip to content

Commit

Permalink
added a firefox matmul backend
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Dec 2, 2024
1 parent 49a80df commit 755e2d3
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FirefoxMatMulInteger);

// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
Expand Down
49 changes: 49 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "firefox_matmul_integer.h"
#include "core/providers/cpu/math/matmul_helper.h"

namespace onnxruntime {
namespace contrib {

ONNX_OPERATOR_KERNEL_EX(
FirefoxMatMulInteger8,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int16_t>()),
FirefoxMatMulInteger8<int8_t, int8_t, int16_t>);

template <>
Status FirefoxMatMulInteger8<int8_t, int8_t, int16_t>::Compute(OpKernelContext* ctx) const {
auto A = ctx->Input<Tensor>(0);
auto B = ctx->Input<Tensor>(1);
ORT_ENFORCE(A != nullptr && B != nullptr);

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(A->Shape(), B->Shape()));
Tensor* Y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0)
return Status::OK();

for (int i = 0; i < static_cast<int>(helper.OutputOffsets().size()); i++) {
EigenCastGEMM<int8_t, int8_t, int16_t>(
A->Data<int8_t>() + helper.LeftOffsets()[i],
B->Data<int8_t>() + helper.RightOffsets()[i],
Y->MutableData<int16_t>() + helper.OutputOffsets()[i],
static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()));
}

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib {

template <typename T1, typename T2, typename T3>
class FirefoxMatMulInteger8 final : public OpKernel {
public:
FirefoxMatMulInteger8(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime
39 changes: 39 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1980,6 +1980,44 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-
ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
}));


constexpr const char* FirefoxMatMulInteger_doc = R"DOC(
Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
)DOC";



ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger, 1,
OpSchema()
.SetDoc(FirefoxMatMulInteger_doc)
.Input(0, "A", "N-dimensional matrix A", "T1")
.Input(1, "B", "N-dimensional matrix B", "T2")
.Output(0, "Y", "Matrix multiply results from A * B", "T3")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data types as 8-bit integer tensor")
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data types as 8-bit integer tensor")
.TypeConstraint("T3",
{"tensor(int16)", "tensor(uint16)"},
"Constrain output Y data types as 32-bit integer tensor."
"T3 must be tensor(uint16) when both T1 and T2 are tensor(uint8),"
"or must be tensor(int16) when either T1 or T2 is tensor(int8).")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto a_type = ctx.getInputType(0);
auto b_type = ctx.getInputType(1);
auto y_type = ctx.getOutputType(0);
if (nullptr == a_type || nullptr == b_type || nullptr == y_type ||
a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
fail_type_inference(
"inputs are expected to have tensor type and output type should not be null.");
}

// Right now we only support int16
y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32);

ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
}));


/**
* @brief Shape inference for MatMul with right hand side matrix quantized into int4
* @param ctx
Expand Down Expand Up @@ -3780,6 +3818,7 @@ Having this op allows runtime to do operator re-ordering to reduce compute FLOPs

#endif


#ifndef _OPSCHEMA_LIB_
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4);
#endif
Expand Down Expand Up @@ -189,6 +190,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger)>());
#ifndef ORT_MINIMAL_BUILD
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4)>());
#endif
Expand Down
50 changes: 50 additions & 0 deletions onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace test {

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_1) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {1, 1}, {15});
test.AddInput<int8_t>("T2", {1, 1}, {8});
test.AddOutput<int32_t>("T3", {1, 1}, {120}); // Result is 15 * 8
test.Run();
}

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_2) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {1, 2}, {-7, 10});
test.AddInput<int8_t>("T2", {2, 1}, {-8, -11});
test.AddOutput<int32_t>("T3", {1, 1}, {8}); // Result is (-7 * -8) + (10 * -11)
test.Run();
}

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_Empty_input) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {0, 2}, {});
test.AddInput<int8_t>("T2", {2, 1}, {-8, -11});
test.AddOutput<int32_t>("T3", {0, 1}, {}); // Empty input produces an empty output
test.Run();
}

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_3) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {3, 2}, {-7, 10, 10, -113, 22, -36});
test.AddInput<int8_t>("T2", {2, 4}, {-8, -11, 13, 14, -9, 12, 3, -6});
test.AddOutput<int32_t>("T3", {3, 4},
{-158, 97, -61, -2, // First row results
989, -1426, 1693, 1682, // Second row results
282, -518, 280, -372}); // Third row results
test.Run();
}

} // namespace test
} // namespace onnxruntime

0 comments on commit 755e2d3

Please sign in to comment.