diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
index 810d3bbba1879..773e5ecb1209b 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
@@ -1,12 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-
 #include <vector>
-
-#ifndef __EMSCRIPTEN__
-#include "gemmology.h"
-#endif
-
+#include <iostream>
 #include "firefox_matmul_integer.h"
 #include "core/providers/cpu/math/matmul_helper.h"
 #include "core/providers/common.h"
@@ -42,27 +37,10 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
         .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
     FirefoxMatMulInteger8);
-
-
-/** Typical Call
-
-Input Tensor A shape: {1,171,1024}
-Input Tensor B shape: {1024,1024}
-A Zero Point shape: {}
-A Zero Point value: 123
-B Zero Point shape: {1024}
-B Zero Point is per-column: 1
-Computing helper with A and B shapes.
-Output Tensor Y shape: {1,171,1024}
-GEMM Shape - M: 171, N: 1024, K: 1024, AIsSigned: 0, BIsSigned: 1
-Batch size: 1
-
-*/
 Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
   const auto* a = ctx->Input<Tensor>(IN_A);
   const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
-  // Validate zero points
   uint8_t a_offset = 0;
   const auto* a_zero_point = ctx->Input<Tensor>(IN_A_ZERO_POINT);
   if (a_zero_point != nullptr) {
@@ -71,35 +49,28 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
     a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
   }
+  uint8_t b_default_offset = 0;
   const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
-
-  #ifndef __EMSCRIPTEN__
   bool b_is_signed;
   const uint8_t* b_offset_ptr = &b_default_offset;
   bool is_b_zp_per_column = false;
-  uint8_t b_default_offset = 0;
   if (b_zero_point != nullptr) {
     ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ?
b->Shape() : b_shape_),
                "MatmulInteger : B zero point is not valid");
    is_b_zp_per_column = !IsScalarOr1ElementVector(b_zero_point);
    b_offset_ptr = static_cast<const uint8_t*>(b_zero_point->DataRaw());
  }
-  #endif
 
   MatMulComputeHelper helper;
   const uint8_t* b_data;
   if (nullptr != b) {
     ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape(), nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
     b_data = static_cast<const uint8_t*>(b->DataRaw());
-  #ifndef __EMSCRIPTEN__
     b_is_signed = b->IsDataType<int8_t>();
-  #endif
   } else {
     ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_, nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
     b_data = static_cast<const uint8_t*>(packed_b_.get());
-  #ifndef __EMSCRIPTEN__
     b_is_signed = b_is_signed_;
-  #endif
   }
 
   Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
@@ -109,71 +80,55 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
   const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
   auto* y_data = y->MutableData<int32_t>();
 
-  #ifdef __EMSCRIPTEN__
-  // Prepare output buffer
-  std::vector<float> float_output(helper.M() * helper.N(), 0.0f);
+  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
+  gemm_shape.M = static_cast<size_t>(helper.M());
+  gemm_shape.N = static_cast<size_t>(helper.N());
+  gemm_shape.K = static_cast<size_t>(helper.K());
+  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
+  gemm_shape.BIsSigned = b_is_signed;
+
+  const size_t batch_size = helper.OutputOffsets().size();
+
+  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
+
+  for (size_t batch = 0; batch < batch_size; batch++) {
+    auto& gemm_params = gemm_data_vec[batch];
+    gemm_params.lda = gemm_shape.K;
+    gemm_params.ZeroPointA = a_offset;
+    gemm_params.ldb = gemm_shape.N;
+    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
+    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
+    gemm_params.ldc = gemm_shape.N;
+    gemm_params.BIsPacked = bool(packed_b_);
+    gemm_params.A = a_data + helper.LeftOffsets()[batch];
+    gemm_params.B = b_data + helper.RightOffsets()[batch];
+    gemm_params.C = y_data + helper.OutputOffsets()[batch];
+
  }
-  // Call the function
-  // matrix A (M x K) * matrix B (K x N)
-  // matrix C (M x N)
-  size_t rows_a = static_cast<size_t>(helper.M());
-  size_t cols_b = static_cast<size_t>(helper.N());
-  size_t width = static_cast<size_t>(helper.K());
+  std::vector<int32_t> int32_output(helper.M() * helper.N(), 0);
 
-  // gemmology is only doing A unsigned x B signed
+  #ifdef __EMSCRIPTEN__
   int8Multiply(reinterpret_cast<const int8_t*>(a_data),
-               a_offset,
-               reinterpret_cast<const int8_t*>(b_data),
-               0,  // b_zero_point
-               rows_a,  // rows A
-               width,  // width
-               cols_b,  // col B
-               reinterpret_cast<float*>(y_data));
-
-
-  // Print the output
-  #if 0
-  std::cout << "Output matrix:\n";
-  for (Index i = 0; i < rows_a; ++i) {
-    for (Index j = 0; j < cols_b; ++j) {
-      std::cout << y_data[i * cols_b + j] << " ";
-    }
-    std::cout << "\n";
-  }
+               a_offset,
+               reinterpret_cast<const int8_t*>(b_data),
+               0,  // b_zero_point
+               static_cast<size_t>(helper.M()),  // rows A
+               static_cast<size_t>(helper.K()),  // width
+               static_cast<size_t>(helper.N()),  // col B
+               reinterpret_cast<float*>(int32_output.data()));
   #endif
-  #else
-  // XXX original call
-  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
-  gemm_shape.M = static_cast<size_t>(helper.M());
-  gemm_shape.N = static_cast<size_t>(helper.N());
-  gemm_shape.K = static_cast<size_t>(helper.K());
-  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
-  gemm_shape.BIsSigned = b_is_signed;
-
-  const size_t batch_size = helper.OutputOffsets().size();
-
-  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
-
-  for (size_t batch = 0; batch < batch_size; batch++) {
-    auto& gemm_params = gemm_data_vec[batch];
-    gemm_params.lda = gemm_shape.K;
-    gemm_params.ZeroPointA = a_offset;
-    gemm_params.ldb = gemm_shape.N;
-    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
-    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
-    gemm_params.ldc = gemm_shape.N;
-    gemm_params.BIsPacked = bool(packed_b_);
-    gemm_params.A = a_data + helper.LeftOffsets()[batch];
-    gemm_params.B = b_data + helper.RightOffsets()[batch];
-    gemm_params.C = y_data + helper.OutputOffsets()[batch];
-  }
-
-  MlasGemmBatch(gemm_shape,
 gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
-  #endif
+  MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
+
+  // Compare the outputs
+  std::cout << "Comparing Outputs:\n";
+  for (size_t i = 0; i < int32_output.size(); ++i) {
+    std::cout << "Index " << i << ": int8Multiply = " << int32_output[i]
+              << ", MlasGemmBatch = " << static_cast<int32_t>(y_data[i]) << "\n";
+  }
+
   return Status::OK();
 }
-
 }  // namespace contrib
 }  // namespace onnxruntime