From 19139caf6cf05d6322daf34e760decf5be635870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Sun, 15 Dec 2024 22:09:33 +0100 Subject: [PATCH] more debug --- .../quantization/firefox_matmul_integer.cc | 319 +++++++----------- 1 file changed, 113 insertions(+), 206 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc index 1a2e9b48e8ef5..3a3b9612f69b4 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include #include +#include #include "firefox_matmul_integer.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" @@ -41,6 +42,68 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), FirefoxMatMulInteger8); + +std::vector MatMulFull(const uint8_t* a_data, const int8_t* b_data, + size_t M, size_t K, size_t N, + int8_t a_offset, const uint8_t* b_offset_ptr) { + std::vector output(M * N, 0); + + for (size_t row_idx = 0; row_idx < M; ++row_idx) { + const uint8_t* a_row = a_data + row_idx * K; // Start of row in A + for (size_t col_idx = 0; col_idx < N; ++col_idx) { + int64_t temp_result = 0; // Use int64_t for intermediate accumulation + + for (size_t k = 0; k < K; ++k) { + // Row-major access + uint8_t a_value = a_row[k]; + int8_t b_value = b_data[k * N + col_idx]; + + // Adjust for zero-point offsets + int32_t adjusted_a = static_cast(a_value) - static_cast(a_offset); + int32_t adjusted_b = static_cast(b_value) - static_cast(b_offset_ptr[col_idx]); + + // Accumulate product + temp_result += static_cast(adjusted_a) * static_cast(adjusted_b); + } + + + int64_t index = row_idx * N + col_idx; + if (index < 10) { + std::cout << " Result for index " << index <<" " << temp_result << "\n"; + } + // Convert to uint32_t, allowing wraparound for negative values + output[row_idx * N + col_idx] = static_cast(temp_result); + } + } + + return output; +} + +void DisplayMatrixSample(const uint32_t* matrix, size_t rows, size_t cols, const std::string& name) { + std::cout << "Sample of " << name << ":\n"; + size_t sample_rows = std::min(rows, static_cast(5)); + size_t sample_cols = std::min(cols, static_cast(5)); + + for (size_t i = 0; i < sample_rows; ++i) { + for (size_t j = 0; j < sample_cols; ++j) { + std::cout << matrix[i * cols + j] << " "; + } + std::cout << "\n"; + } +} + +void CompareMatrices(const uint32_t* matrix1, const uint32_t* matrix2, size_t rows, size_t cols, const std::string& matrix1_name, const std::string& matrix2_name) { + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + if (matrix1[i * cols + j] != matrix2[i * cols + j]) { + throw std::runtime_error( + "Mismatch between " + matrix1_name + " and " + + matrix2_name + " at row " + std::to_string(i) + ", col " + std::to_string(j)); + } + } + } +} + Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { const auto* a = ctx->Input(IN_A); const auto* b = packed_b_ ? nullptr : ctx->Input(IN_B); @@ -77,6 +140,10 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { b_is_signed = b_is_signed_; } + size_t M = static_cast(helper.M()); + size_t K = static_cast(helper.K()); + size_t N = static_cast(helper.N()); + Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); if (y->Shape().Size() == 0) { return Status::OK(); @@ -85,9 +152,9 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { auto* y_data = y->MutableData(); MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape; - gemm_shape.M = static_cast(helper.M()); - gemm_shape.N = static_cast(helper.N()); - gemm_shape.K = static_cast(helper.K()); + gemm_shape.M = M; + gemm_shape.N = N; + gemm_shape.K = K; gemm_shape.AIsSigned = a->IsDataType(); gemm_shape.BIsSigned = b_is_signed; @@ -109,209 +176,49 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { gemm_params.C = y_data + helper.OutputOffsets()[batch]; } - std::vector gemmology_output(helper.M() * helper.N(), 0); - - #ifdef __EMSCRIPTEN__ - uint8_t zero_point_b = *(b_offset_ptr + helper.RightZeroPointOffsets()[0]); - - std::cout << "A Zero point: " << static_cast(a_offset) << "\n"; - std::cout << "B zero_point: " << static_cast(zero_point_b) << "\n"; - std::cout << "rows A: " << helper.M() << ", width: " << helper.K() << ", Cols B: " << helper.N() << "\n"; - std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n"; - std::cout << "B is signed: " << (b_is_signed ? "true" : "false") << "\n"; - - -std::cout << "Zero Points Debug:\n"; -std::cout << "A Zero Point: " << static_cast(a_offset) << "\n"; -std::cout << "B Zero Points (all columns): "; -for (size_t i = 0; i < static_cast(helper.N()); ++i) { - std::cout << static_cast(b_offset_ptr[i]) << " "; -} -std::cout << "\n"; - -std::cout << "Matrix Dimensions:\n"; -std::cout << "M (rows A): " << gemm_shape.M << ", K (width): " << gemm_shape.K - << ", N (cols B): " << gemm_shape.N << "\n"; - -std::cout << "Signedness:\n"; -std::cout << "AIsSigned: " << (gemm_shape.AIsSigned ? "true" : "false") << "\n"; -std::cout << "BIsSigned: " << (gemm_shape.BIsSigned ? "true" : "false") << "\n"; - - -std::cout << "Matrix A (sample):\n"; -for (size_t i = 0; i < 5; ++i) { - for (size_t j = 0; j < 5; ++j) { - std::cout << static_cast(a_data[i * helper.K() + j]) << " "; - } - std::cout << "\n"; -} - -std::cout << "Matrix B (sample):\n"; -for (size_t i = 0; i < 5; ++i) { - for (size_t j = 0; j < 5; ++j) { - std::cout << static_cast(b_data[i * helper.N() + j]) << " "; - } - std::cout << "\n"; -} -std::cout << "Offsets Debug:\n"; -std::cout << "Left Offsets (A): "; -for (size_t i = 0; i < batch_size; ++i) { - std::cout << helper.LeftOffsets()[i] << " "; -} -std::cout << "\n"; - -std::cout << "Right Offsets (B): "; -for (size_t i = 0; i < batch_size; ++i) { - std::cout << helper.RightOffsets()[i] << " "; -} -std::cout << "\n"; - -std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n"; - - -// Manually compute the first value of the first row of the output -uint32_t manual_result = 0; - -std::cout << "Dimensions: M = " << helper.M() << ", K = " << helper.K() << ", N = " << helper.N() << "\n"; - - std::cout << "Manually computing first value of the output matrix (Row 0, Col 0):\n"; - - int64_t temp_result = 0; // Use a signed type for accumulation to handle potential negatives - for (size_t k = 0; k < static_cast(helper.K()); ++k) { - uint8_t a_value = static_cast(a_data[k]); // First row of A (unsigned) - int8_t b_value = static_cast(b_data[k * helper.N()]); // First column of B (signed) - - // Adjust for zero points - int32_t adjusted_a = static_cast(a_value) - static_cast(a_offset); // A is unsigned - int32_t adjusted_b = static_cast(b_value) - static_cast(b_offset_ptr[0]); // B is signed - - // Accumulate the signed result - temp_result += static_cast(adjusted_a) * static_cast(adjusted_b); - - // Debugging individual terms - std::cout << "k = " << k - << ", A[k] = " << static_cast(a_value) - << ", B[k, 0] = " << static_cast(b_value) - << ", Adjusted A[k] = " << adjusted_a - << ", Adjusted B[k, 0] = " << adjusted_b - << ", Partial Sum (signed) = " << temp_result << "\n"; - } - - // Ensure the result fits in uint32_t, saturating if necessary - manual_result = static_cast(std::max(0, temp_result)); // Clamp to 0 for unsigned range - - std::cout << "Manual computation result (Row 0, Col 0): " << manual_result << "\n"; - - - // Gemmology call - std::cout << "Calling gemmology from onnx:\n"; - auto start_gemmology = Clock::now(); - - int8Multiply( - reinterpret_cast(a_data), - 0, // a_offset, - reinterpret_cast(b_data), - b_offset_ptr[0], - static_cast(helper.M()), // rows A - static_cast(helper.K()), // width - static_cast(helper.N()), // col B - reinterpret_cast(gemmology_output.data())); - - auto end_gemmology = Clock::now(); - auto gemmology_time = std::chrono::duration_cast(end_gemmology - start_gemmology).count(); - std::cout << "gemmology call complete.\n"; - - std::cout << "Call done\n"; - - std::cout << "Manually Clamping\n"; - - for (size_t i = 0; i < static_cast(helper.M()); ++i) { - for (size_t j = 0; j < static_cast(helper.N()); ++j) { - size_t index = i * static_cast(helper.N()) + j; - - // Interpret unsigned value as signed - uint32_t raw_value = gemmology_output[index]; - //std::cout << "Index (" << i << ", " << j << "), Original Value (unsigned): " << raw_value << "\n"; - - int32_t signed_value = static_cast(raw_value); - //std::cout << "Index (" << i << ", " << j << "), Interpreted as Signed: " << signed_value << "\n"; - - - // Clamp to non-negative - uint32_t clamped_value = static_cast(std::max(0, signed_value)); - - // Write clamped value back to output - gemmology_output[index] = clamped_value; - - // Log for debugging - if (i == 0 && j == 0) { // Only log the first value - std::cout << "Post-process Clamping for Index (0, 0):\n"; - std::cout << "Raw Value (unsigned): " << raw_value << "\n"; - std::cout << "Interpreted as Signed: " << signed_value << "\n"; - std::cout << "Clamped Value: " << clamped_value << "\n"; - } - } - - -} - - - #endif -std::cout << "Calling MlasGemmBatch\n"; - -auto start_mblas = Clock::now(); - - // Original MatmulInteger call -MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); -auto end_mblas = Clock::now(); -auto mblas_time = std::chrono::duration_cast(end_mblas - start_mblas).count(); - - -std::cout << "Calling MlasGemmBatch done\n"; -// Compute percentage difference -double percentage_diff = (static_cast(gemmology_time - mblas_time) / mblas_time) * 100.0; - -// Display the results -std::cout << "Execution Times (Microseconds): MBlas = " << mblas_time - << ", Gemmology = " << gemmology_time - << ", Difference = " << percentage_diff << "%\n"; - - - - - // Compare the outputs - std::cout << "Comparing Outputs:\n"; - //for (size_t i = 0; i < static_cast(helper.M()); ++i) { - for (size_t i = 0; i < 2; ++i) { - //for (size_t j = 0; j < static_cast(helper.N()); ++j) { - for (size_t j = 0; j < 2; ++j) { - std::cout << "Gemmology:"; - std::cout << static_cast(gemmology_output[i * helper.N() + j]) << "\n"; - std::cout << "MBLas:"; - std::cout << static_cast(y_data[i * helper.N() + j]) << "\n"; - } - std::cout << "\n"; - } -std::cout << "Comparing\n"; - - -for (size_t i = 0; i < static_cast(helper.M()); ++i) { - for (size_t j = 0; j < static_cast(helper.N()); ++j) { - std::cout << "Mismatch lookup\n"; - - - size_t index = i * helper.N() + j; - std::cout << "Lookup at Row " << i << ", Col " << j << ": " << index << "\n"; - - if (gemmology_output[index] != static_cast(y_data[index])) { - std::cout << "Mismatch"; - - ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", gemmology_output[index], - ", MlasGemmBatch = ", static_cast(y_data[index])); - } - } -} - + std::vector gemmology_output(helper.M() * helper.N(), 0); + + // Manual MatMul + auto start_matmul = Clock::now(); + std::vector matmul_output = MatMulFull(a_data, reinterpret_cast(b_data), M, K, N, a_offset, b_offset_ptr); + auto end_matmul = Clock::now(); + auto matmul_time = std::chrono::duration_cast(end_matmul - start_matmul).count(); + + // Gemmology + auto start_gemmology = Clock::now(); + int8Multiply( + reinterpret_cast(a_data), + a_offset, + reinterpret_cast(b_data), + b_offset_ptr[0], + M, + N, + K, + reinterpret_cast(gemmology_output.data())); + + auto end_gemmology = Clock::now(); + auto gemmology_time = std::chrono::duration_cast(end_gemmology - start_gemmology).count(); + + // Mlas + auto start_mblas = Clock::now(); + MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); + auto end_mblas = Clock::now(); + auto mblas_time = std::chrono::duration_cast(end_mblas - start_mblas).count(); + + // Display samples + DisplayMatrixSample(matmul_output.data(), M, N, "MatMulFull Output"); + DisplayMatrixSample(gemmology_output.data(), M, N, "gemmology Output"); + DisplayMatrixSample(reinterpret_cast(y_data), M, N, "MLas Output"); + + // make sure the three implementations return the same data + CompareMatrices(matmul_output.data(), reinterpret_cast(y_data), M, N, "MatMulFull", "MLas"); + CompareMatrices(matmul_output.data(), gemmology_output.data(), M, N, "MatMulFull", "gemmology"); + + // Output timing results + std::cout << "Timing (microseconds):\n"; + std::cout << "MatMulFull: " << matmul_time << "\n"; + std::cout << "Mlas: " << mblas_time << "\n"; + std::cout << "Gemmology: " << gemmology_time << "\n"; return Status::OK(); }