From 1253aadac7318f80437f516c93269f0bcca58a4d Mon Sep 17 00:00:00 2001
From: mbertuletti
Date: Mon, 8 Jan 2024 11:38:24 +0100
Subject: [PATCH] [software] Add complex instructions to fp16 cmatmul

---
 software/apps/cmatmul_f16/main.c              |  43 +-
 software/apps/ofdm/main.c                     |   4 +-
 software/runtime/data/data_cmatmul_f16.h.tpl  |   2 +-
 software/runtime/kernel/cmatmul_f16.h         | 217 ---------
 software/runtime/kernel/mempool_checks.h      |  12 +-
 software/runtime/kernel/mempool_cmatmul_f16.h | 433 +++++++-----------
 6 files changed, 203 insertions(+), 508 deletions(-)

diff --git a/software/apps/cmatmul_f16/main.c b/software/apps/cmatmul_f16/main.c
index 1d999a0d8..3a0c2c414 100644
--- a/software/apps/cmatmul_f16/main.c
+++ b/software/apps/cmatmul_f16/main.c
@@ -2,7 +2,7 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0

-// Author: Samuel Riedel, ETH Zurich
+// Author: Marco Bertuletti, ETH Zurich

 #include <stdint.h>
 #include <string.h>
@@ -15,7 +15,9 @@
 #include "data/data_cmatmul_f16.h"
 #include "kernel/mempool_checks.h"
 #include "kernel/mempool_cmatmul_f16.h"
-#define PARALLEL_2x2
+#define PARALLEL
+#define LOOP_2x4
+// #define TEST

 __fp16 matrix_a[2 * dim_M * dim_N]
     __attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
@@ -38,8 +40,8 @@ int main() {

   // Initialize Matrices
   if (core_id == 0) {
-    dma_memcpy_blocking(matrix_a, A, dim_M * dim_N * sizeof(int32_t));
-    dma_memcpy_blocking(matrix_b, B, dim_N * dim_P * sizeof(int32_t));
+    dma_memcpy_blocking(matrix_a, A, 2 * dim_M * dim_N * sizeof(int16_t));
+    dma_memcpy_blocking(matrix_b, B, 2 * dim_N * dim_P * sizeof(int16_t));
   }
   // Wait at barrier until everyone is ready
   mempool_barrier(num_cores);
@@ -48,40 +50,43 @@ int main() {
   // Execute function to test.
   if (core_id == 0) {
     mempool_start_benchmark();
-    cmatmul_2x4_f16s(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P);
+    cmatmul_2x2_f16s(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P);
     mempool_stop_benchmark();
   }
   mempool_barrier(num_cores);
 #endif

-#if defined(PARALLEL_2x2)
+#if defined(PARALLEL)
   // Execute function to test.
-  uint32_t nPE = core_id < (dim_P / 2) ? num_cores : (dim_P / 2);
+
+#if defined(LOOP_2x2)
+  uint32_t nPE = num_cores < (dim_P / 2) ? num_cores : (dim_P / 2); // 2x2
   if (core_id < nPE) {
     mempool_start_benchmark();
-    cmatmul_2x2_f16p(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P, core_id,
-                     nPE);
-    mempool_log_partial_barrier(2, core_id, nPE);
+    cmatmul_2x2_f16p((v2h *)matrix_a, (v2h *)matrix_b, (v2h *)matrix_c, dim_M,
+                     dim_N, dim_P, core_id, nPE);
     mempool_stop_benchmark();
   }
-  mempool_barrier(num_cores);
 #endif
-
-#if defined(PARALLEL_2x4)
-  // Execute function to test.
-  uint32_t nPE = core_id < (dim_P / 4) ? num_cores : (dim_P / 4);
+#if defined(LOOP_2x4)
+  uint32_t nPE = num_cores < (dim_P / 4) ?
num_cores : (dim_P / 4); // 2x4 if (core_id < nPE) { mempool_start_benchmark(); - cmatmul_2x4_f16p(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P, core_id, - nPE); - mempool_log_partial_barrier(2, core_id, nPE); + // cmatmul_2x4_f16p( + // (v2h*) matrix_a, + // (v2h*) matrix_b, + // (v2h*) matrix_c, + // dim_M, dim_N, dim_P, core_id, nPE); + cmatmul_2x4_folded_f16p(matrix_a, matrix_b, matrix_a_folded, matrix_c, + dim_M, dim_N, dim_P, core_id, nPE); mempool_stop_benchmark(); } +#endif mempool_barrier(num_cores); #endif #if defined(TEST) - mempool_check_f32(matrix_c, C, 2 * dim_M * dim_P, 0.01f, 0); + mempool_check_f16(matrix_c, C, 2 * dim_M * dim_P, 0.1f, 0); mempool_barrier(num_cores); #endif diff --git a/software/apps/ofdm/main.c b/software/apps/ofdm/main.c index 1808669c6..8408c1035 100644 --- a/software/apps/ofdm/main.c +++ b/software/apps/ofdm/main.c @@ -101,8 +101,8 @@ int main() { /* BEAMFORMING */ mempool_start_benchmark(); - cmatmul_2x4_folded_f16p(l1_pBF_Coef_folded, l1_pFFT_Src, l1_pFFT_Dst, N_BEAMS, - N_RX, N_SC, core_id, num_cores); + cmatmul_2x4_folded_f16p(l1_pBF_Coef_folded, l1_pBF_Coef_folded, l1_pFFT_Src, + l1_pFFT_Dst, N_BEAMS, N_RX, N_SC, core_id, num_cores); mempool_stop_benchmark(); dump_prova(2); diff --git a/software/runtime/data/data_cmatmul_f16.h.tpl b/software/runtime/data/data_cmatmul_f16.h.tpl index bc67d0f45..c6f1519b9 100644 --- a/software/runtime/data/data_cmatmul_f16.h.tpl +++ b/software/runtime/data/data_cmatmul_f16.h.tpl @@ -7,7 +7,7 @@ i = 0 out += '\n' for a in array: - out += '(__fp16){:.5f}, '.format(a) + out += '(__fp16){:.4f}, '.format(a) i += 1 if i % 8 == 0: out += '\n' diff --git a/software/runtime/kernel/cmatmul_f16.h b/software/runtime/kernel/cmatmul_f16.h index c80473679..e69de29bb 100644 --- a/software/runtime/kernel/cmatmul_f16.h +++ b/software/runtime/kernel/cmatmul_f16.h @@ -1,217 +0,0 @@ -// Copyright 2021 ETH Zurich and University of Bologna. -// Licensed under the Apache License, Version 2.0, see LICENSE for details. -// SPDX-License-Identifier: Apache-2.0 - -// Author: Marco Bertuletti, ETH Zurich - -/* This library implements the complex matrix multiplication in multiple - * different ways. 
The functions all follow the following format: - * - * A is an M x N matrix, B is a N x P matrix, and C is a M x P matrix - * C = AB - */ - -#include "xpulp/builtins_v2.h" - -#define CMATMUL_2x4_LOOP \ - float sum00_real = 0.0f; \ - float sum01_real = 0.0f; \ - float sum02_real = 0.0f; \ - float sum03_real = 0.0f; \ - float sum10_real = 0.0f; \ - float sum11_real = 0.0f; \ - float sum12_real = 0.0f; \ - float sum13_real = 0.0f; \ - float sum00_imag = 0.0f; \ - float sum01_imag = 0.0f; \ - float sum02_imag = 0.0f; \ - float sum03_imag = 0.0f; \ - float sum10_imag = 0.0f; \ - float sum11_imag = 0.0f; \ - float sum12_imag = 0.0f; \ - float sum13_imag = 0.0f; \ - for (j = 0; j < N; j += 2) { \ - v2h a00 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 0))]); \ - v2h a01 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 1))]); \ - v2h a10 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 0))]); \ - v2h a11 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 1))]); \ - v2h b00 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 0))]); \ - v2h b01 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 1))]); \ - v2h b02 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 2))]); \ - v2h b03 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 3))]); \ - v2h b10 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 0))]); \ - v2h b11 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 1))]); \ - v2h b12 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 2))]); \ - v2h b13 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 3))]); \ - v2h a00s, a01s, a10s, a11s; \ - asm volatile("pv.shuffle2.h %[a00s], %[a00], %[mask];" \ - "pv.shuffle2.h %[a01s], %[a01], %[mask];" \ - "pv.shuffle2.h %[a10s], %[a10], %[mask];" \ - "pv.shuffle2.h %[a11s], %[a11], %[mask];" \ - : [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \ - [a11s] "+r"(a11s) \ - : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), \ - [a11] "r"(a11), [mask] "r"(0x00020003) \ - :); \ - asm volatile( \ - "vfdotpex.s.h %[sum00_imag], %[a00s], %[b00];" \ - "vfdotpex.s.h %[sum10_imag], %[a10s], %[b00];" \ - "vfdotpex.s.h %[sum01_imag], %[a00s], %[b01];" \ - "vfdotpex.s.h %[sum11_imag], %[a10s], %[b01];" \ - "vfdotpex.s.h %[sum02_imag], %[a00s], %[b02];" \ - "vfdotpex.s.h %[sum12_imag], %[a10s], %[b02];" \ - "vfdotpex.s.h %[sum03_imag], %[a00s], %[b03];" \ - "vfdotpex.s.h %[sum13_imag], %[a10s], %[b03];" \ - "vfdotpex.s.h %[sum00_imag], %[a01s], %[b10];" \ - "vfdotpex.s.h %[sum10_imag], %[a11s], %[b10];" \ - "vfdotpex.s.h %[sum01_imag], %[a01s], %[b11];" \ - "vfdotpex.s.h %[sum11_imag], %[a11s], %[b11];" \ - "vfdotpex.s.h %[sum02_imag], %[a01s], %[b12];" \ - "vfdotpex.s.h %[sum12_imag], %[a11s], %[b12];" \ - "vfdotpex.s.h %[sum03_imag], %[a01s], %[b13];" \ - "vfdotpex.s.h %[sum13_imag], %[a11s], %[b13];" \ - : [sum00_imag] "+&r"(sum00_imag), [sum01_imag] "+&r"(sum01_imag), \ - [sum02_imag] "+&r"(sum02_imag), [sum03_imag] "+&r"(sum03_imag), \ - [sum10_imag] "+&r"(sum10_imag), [sum11_imag] "+&r"(sum11_imag), \ - [sum12_imag] "+&r"(sum12_imag), [sum13_imag] "+&r"(sum13_imag) \ - : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ - [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \ - [b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \ - [b13] "r"(b13) \ - :); \ - asm volatile("xor %[a00s], %[a00], %[mask];" \ - "xor %[a01s], %[a01], %[mask];" \ - "xor %[a10s], %[a10], %[mask];" \ - "xor %[a11s], %[a11], %[mask];" \ - : [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \ - [a11s] "+r"(a11s) \ - : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), \ - [a11] "r"(a11), [mask] "r"(0x80000000) \ - :); \ - asm volatile( \ - "vfdotpex.s.h %[sum00_real], %[a00s], %[b00];" \ - 
"vfdotpex.s.h %[sum10_real], %[a10s], %[b00];" \ - "vfdotpex.s.h %[sum01_real], %[a00s], %[b01];" \ - "vfdotpex.s.h %[sum11_real], %[a10s], %[b01];" \ - "vfdotpex.s.h %[sum02_real], %[a00s], %[b02];" \ - "vfdotpex.s.h %[sum12_real], %[a10s], %[b02];" \ - "vfdotpex.s.h %[sum03_real], %[a00s], %[b03];" \ - "vfdotpex.s.h %[sum13_real], %[a10s], %[b03];" \ - "vfdotpex.s.h %[sum00_real], %[a01s], %[b10];" \ - "vfdotpex.s.h %[sum10_real], %[a11s], %[b10];" \ - "vfdotpex.s.h %[sum01_real], %[a01s], %[b11];" \ - "vfdotpex.s.h %[sum11_real], %[a11s], %[b11];" \ - "vfdotpex.s.h %[sum02_real], %[a01s], %[b12];" \ - "vfdotpex.s.h %[sum12_real], %[a11s], %[b12];" \ - "vfdotpex.s.h %[sum03_real], %[a01s], %[b13];" \ - "vfdotpex.s.h %[sum13_real], %[a11s], %[b13];" \ - : [sum00_real] "+&r"(sum00_real), [sum01_real] "+&r"(sum01_real), \ - [sum02_real] "+&r"(sum02_real), [sum03_real] "+&r"(sum03_real), \ - [sum10_real] "+&r"(sum10_real), [sum11_real] "+&r"(sum11_real), \ - [sum12_real] "+&r"(sum12_real), [sum13_real] "+&r"(sum13_real) \ - : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ - [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \ - [b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \ - [b13] "r"(b13) \ - :); \ - } \ - v2h res00, res01, res02, res03; \ - v2h res10, res11, res12, res13; \ - asm volatile( \ - "vfcpka.h.s %[res00], %[sum00_real], %[sum00_imag];" \ - "vfcpka.h.s %[res01], %[sum01_real], %[sum01_imag];" \ - "vfcpka.h.s %[res02], %[sum02_real], %[sum02_imag];" \ - "vfcpka.h.s %[res03], %[sum03_real], %[sum03_imag];" \ - "vfcpka.h.s %[res10], %[sum10_real], %[sum10_imag];" \ - "vfcpka.h.s %[res11], %[sum11_real], %[sum11_imag];" \ - "vfcpka.h.s %[res12], %[sum12_real], %[sum12_imag];" \ - "vfcpka.h.s %[res13], %[sum13_real], %[sum13_imag];" \ - : [res00] "+r"(res00), [res01] "+r"(res01), [res02] "+r"(res02), \ - [res03] "+r"(res03), [res10] "+r"(res10), [res11] "+r"(res11), \ - [res12] "+r"(res12), [res13] "+r"(res13) \ - : [sum00_imag] "r"(sum00_imag), [sum01_imag] "r"(sum01_imag), \ - [sum02_imag] "r"(sum02_imag), [sum03_imag] "r"(sum03_imag), \ - [sum10_imag] "r"(sum10_imag), [sum11_imag] "r"(sum11_imag), \ - [sum12_imag] "r"(sum12_imag), [sum13_imag] "r"(sum13_imag), \ - [sum00_real] "r"(sum00_real), [sum01_real] "r"(sum01_real), \ - [sum02_real] "r"(sum02_real), [sum03_real] "r"(sum03_real), \ - [sum10_real] "r"(sum10_real), [sum11_real] "r"(sum11_real), \ - [sum12_real] "r"(sum12_real), [sum13_real] "r"(sum13_real) \ - :); \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 0)]) = res00; \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 1)]) = res01; \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 2)]) = res02; \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 3)]) = res03; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 0)]) = res10; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 1)]) = res11; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 2)]) = res12; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 3)]) = res13; - -void cmatmul_f16s(__fp16 const *__restrict__ A, __fp16 const *__restrict__ B, - __fp16 *__restrict__ C, uint32_t M, uint32_t N, uint32_t P) { - uint32_t i = 0; // loop counter for M - uint32_t j = 0; // loop counter for N - uint32_t k = 0; // loop counter for P - for (k = 0; k < P; k++) { - for (i = 0; i < M; i++) { - float sum_real = 0.0f; - float sum_imag = 0.0f; - for (j = 0; j < N; j++) { - v2h a = (*(v2h *)&A[2 * (i * N + j)]); - v2h b = (*(v2h *)&B[2 * (j * P + k)]); - v2h as; - asm volatile("pv.shuffle2.h %[as], %[a], %[mask_shuffle];" - "vfdotpex.s.h %[sum_imag], %[as], %[b];" - "xor %[as], %[a], 
%[mask_sign];" - "vfdotpex.s.h %[sum_real], %[as], %[b];" - : [sum_real] "+&r"(sum_real), [sum_imag] "+&r"(sum_imag), - [as] "+&r"(as) - : [a] "r"(a), [b] "r"(b), [mask_shuffle] "r"(0x00020003), - [mask_sign] "r"(0x80000000) - :); - } - v2h res; - asm volatile("vfcpka.h.s %[res], %[sum_real], %[sum_imag];" - : [res] "+r"(res) - : [sum_real] "r"(sum_real), [sum_imag] "r"(sum_imag) - :); - (*(v2h *)&C[2 * (i * P + k)]) = res; - } - } - return; -} - -void cmatmul_2x4_f16s(__fp16 const *__restrict__ A, - __fp16 const *__restrict__ B, __fp16 *__restrict__ C, - uint32_t M, uint32_t N, uint32_t P) { - uint32_t i = 0; // loop counter for M - uint32_t j = 0; // loop counter for N - uint32_t k = 0; // loop counter for P - for (k = 0; k < P; k += 4) { - for (i = 0; i < M; i += 2) { - CMATMUL_2x4_LOOP; - } - } - return; -} - -void cmatmul_2x4_f16p(__fp16 const *__restrict__ A, - __fp16 const *__restrict__ B, __fp16 *__restrict__ C, - uint32_t M, uint32_t N, uint32_t P, uint32_t core_id, - uint32_t numThreads) { - - uint32_t i = 0; // loop counter for M - uint32_t j = 0; // loop counter for N - uint32_t k = 0; // loop counter for P - - uint32_t shift_id = core_id % (M / 2); - for (k = core_id * 4; k < P; k += 4 * numThreads) { - for (i = 0; i < shift_id; i += 2) { - CMATMUL_2x4_LOOP; - } - for (i = shift_id; i < M; i += 2) { - CMATMUL_2x4_LOOP; - } - } - return; -} diff --git a/software/runtime/kernel/mempool_checks.h b/software/runtime/kernel/mempool_checks.h index aa8333bcd..46876472e 100644 --- a/software/runtime/kernel/mempool_checks.h +++ b/software/runtime/kernel/mempool_checks.h @@ -22,7 +22,7 @@ void mempool_check_q32(int32_t *__restrict__ pRes, int32_t *__restrict__ pExp, int32_t exp = pExp[i]; int32_t res = pRes[i]; error = exp - res; - bool print = ((error > TOL) || (error < (-TOL))) | verbose; + bool print = ((error > TOL) || (error < (-TOL))) || verbose; if (print) { printf("CHECK(%d): EXP = %x - RESP = %x\n", i, exp, res); ERRORS++; @@ -48,9 +48,9 @@ void mempool_check_q16(int16_t *__restrict__ pRes, int16_t *__restrict__ pExp, if (core_id == 0) { uint32_t ERRORS = 0; for (uint32_t i = 0; i < NEL; i++) { - int16_t exp = (int16_t) pExp[i]; - int16_t res = (int16_t) pRes[i]; - error = (int16_t) (exp - res); + int16_t exp = (int16_t)pExp[i]; + int16_t res = (int16_t)pRes[i]; + error = (int16_t)(exp - res); bool print = ((error > TOL) || (error < (-TOL))) | verbose; if (print) { printf("CHECK(%d): EXP = %x - RESP = %x\n", i, exp, res); @@ -84,7 +84,7 @@ void mempool_check_f32(float *__restrict__ pRes, float *__restrict__ pExp, : [error] "+&r"(error) : [res] "r"(res), [exp] "r"(exp) :); - bool print = ((error > TOL) || (error < (-TOL))) | verbose; + bool print = ((error > TOL) || (error < (-TOL))) || verbose; if (print) { printf("CHECK(%d): EXP = %x - RESP = %x\n", i, exp, res); ERRORS++; @@ -117,7 +117,7 @@ void mempool_check_f16(__fp16 *__restrict__ pRes, __fp16 *__restrict__ pExp, : [error] "+&r"(error) : [res] "r"(res), [exp] "r"(exp) :); - bool print = ((error > TOL) || (error < (-TOL))) | verbose; + bool print = ((error > TOL) || (error < (-TOL))) || verbose; if (print) { printf("CHECK(%d): EXP = %x - RESP = %x\n", i, *(int32_t *)&exp, *(int32_t *)&res); diff --git a/software/runtime/kernel/mempool_cmatmul_f16.h b/software/runtime/kernel/mempool_cmatmul_f16.h index 195af694b..64159affe 100644 --- a/software/runtime/kernel/mempool_cmatmul_f16.h +++ b/software/runtime/kernel/mempool_cmatmul_f16.h @@ -11,7 +11,32 @@ * C = AB */ +#pragma once #include "xpulp/builtins_v2.h" +#define NUM_BANKS 
(NUM_CORES * BANKING_FACTOR) + +#define CMATMUL_1x1_LOOP \ + float sum_real = 0.0f; \ + float sum_imag = 0.0f; \ + v2h res, as; \ + for (j = 0; j < N; j++) { \ + v2h a = *(v2h *)&A[2 * (i * N + j)]; \ + v2h b = *(v2h *)&B[2 * (j * P + k)]; \ + asm volatile( \ + "pv.shuffle2.h %[as], %[a], %[mask_shuffle];" \ + "vfdotpex.s.h %[sum_imag], %[as], %[b];" \ + "xor %[as], %[a], %[mask_sign];" \ + "vfdotpex.s.h %[sum_real], %[as], %[b];" \ + : [sum_real] "+r"(sum_real), [sum_imag] "+r"(sum_imag), [as] "+r"(as) \ + : [a] "r"(a), [b] "r"(b), [mask_shuffle] "r"(0x00020003), \ + [mask_sign] "r"(0x80000000) \ + :); \ + } \ + asm volatile("vfcpka.h.s %[res], %[sum_real], %[sum_imag];" \ + : [res] "+r"(res) \ + : [sum_real] "r"(sum_real), [sum_imag] "r"(sum_imag) \ + :); \ + (*(v2h *)&C[2 * (i * P + k)]) = res; #define CMATMUL_2x2_LOOP \ float sum00_real = 0.0f; \ @@ -22,16 +47,17 @@ float sum01_imag = 0.0f; \ float sum10_imag = 0.0f; \ float sum11_imag = 0.0f; \ + v2h a00s, a01s, a10s, a11s; \ + v2h res00, res01, res10, res11; \ for (j = 0; j < N; j += 2) { \ - v2h a00 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 0))]); \ - v2h a01 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 1))]); \ - v2h a10 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 0))]); \ - v2h a11 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 1))]); \ - v2h b00 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 0))]); \ - v2h b01 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 1))]); \ - v2h b10 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 0))]); \ - v2h b11 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 1))]); \ - v2h a00s, a01s, a10s, a11s; \ + v2h a00 = *(v2h *)&A[2 * ((i + 0) * N + (j + 0))]; \ + v2h a01 = *(v2h *)&A[2 * ((i + 0) * N + (j + 1))]; \ + v2h a10 = *(v2h *)&A[2 * ((i + 1) * N + (j + 0))]; \ + v2h a11 = *(v2h *)&A[2 * ((i + 1) * N + (j + 1))]; \ + v2h b00 = *(v2h *)&B[2 * ((j + 0) * P + (k + 0))]; \ + v2h b01 = *(v2h *)&B[2 * ((j + 0) * P + (k + 1))]; \ + v2h b10 = *(v2h *)&B[2 * ((j + 1) * P + (k + 0))]; \ + v2h b11 = *(v2h *)&B[2 * ((j + 1) * P + (k + 1))]; \ asm volatile("pv.shuffle2.h %[a00s], %[a00], %[mask];" \ "pv.shuffle2.h %[a10s], %[a10], %[mask];" \ "pv.shuffle2.h %[a01s], %[a01], %[mask];" \ @@ -50,8 +76,8 @@ "vfdotpex.s.h %[sum10_imag], %[a11s], %[b10];" \ "vfdotpex.s.h %[sum01_imag], %[a01s], %[b11];" \ "vfdotpex.s.h %[sum11_imag], %[a11s], %[b11];" \ - : [sum00_imag] "=r"(sum00_imag), [sum01_imag] "=r"(sum01_imag), \ - [sum10_imag] "=r"(sum10_imag), [sum11_imag] "=r"(sum11_imag) \ + : [sum00_imag] "+&r"(sum00_imag), [sum01_imag] "+&r"(sum01_imag), \ + [sum10_imag] "+&r"(sum10_imag), [sum11_imag] "+&r"(sum11_imag) \ : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b10] "r"(b10), \ [b11] "r"(b11) \ @@ -74,15 +100,13 @@ "vfdotpex.s.h %[sum10_real], %[a11s], %[b10];" \ "vfdotpex.s.h %[sum01_real], %[a01s], %[b11];" \ "vfdotpex.s.h %[sum11_real], %[a11s], %[b11];" \ - : [sum00_real] "=r"(sum00_real), [sum01_real] "=r"(sum01_real), \ - [sum10_real] "=r"(sum10_real), [sum11_real] "=r"(sum11_real) \ + : [sum00_real] "+&r"(sum00_real), [sum01_real] "+&r"(sum01_real), \ + [sum10_real] "+&r"(sum10_real), [sum11_real] "+&r"(sum11_real) \ : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b10] "r"(b10), \ [b11] "r"(b11) \ :); \ } \ - v2h res00, res01; \ - v2h res10, res11; \ asm volatile("vfcpka.h.s %[res00], %[sum00_real], %[sum00_imag];" \ "vfcpka.h.s %[res01], %[sum01_real], %[sum01_imag];" \ "vfcpka.h.s %[res10], %[sum10_real], %[sum10_imag];" \ @@ -116,26 +140,28 @@ float 
sum11_imag = 0.0f;                                                  \
   float sum12_imag = 0.0f;                                                  \
   float sum13_imag = 0.0f;                                                  \
+  v2h a00s, a01s, a10s, a11s;                                               \
+  v2h res00, res01, res02, res03;                                           \
+  v2h res10, res11, res12, res13;                                           \
   for (j = 0; j < N; j += 2) {                                              \
-    v2h a00 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 0))]);                    \
-    v2h a01 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 1))]);                    \
-    v2h a10 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 0))]);                    \
-    v2h a11 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 1))]);                    \
-    v2h b00 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 0))]);                    \
-    v2h b01 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 1))]);                    \
-    v2h b02 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 2))]);                    \
-    v2h b03 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 3))]);                    \
-    v2h b10 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 0))]);                    \
-    v2h b11 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 1))]);                    \
-    v2h b12 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 2))]);                    \
-    v2h b13 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 3))]);                    \
-    v2h a00s, a01s, a10s, a11s;                                             \
+    v2h a00 = *(v2h *)&A_fetch[2 * ((i + 0) * N + (j + 0))];                \
+    v2h a01 = *(v2h *)&A_fetch[2 * ((i + 0) * N + (j + 1))];                \
+    v2h a10 = *(v2h *)&A_fetch[2 * ((i + 1) * N + (j + 0))];                \
+    v2h a11 = *(v2h *)&A_fetch[2 * ((i + 1) * N + (j + 1))];                \
+    v2h b00 = *(v2h *)&B[2 * ((j + 0) * P + (k + 0))];                      \
+    v2h b01 = *(v2h *)&B[2 * ((j + 0) * P + (k + 1))];                      \
+    v2h b02 = *(v2h *)&B[2 * ((j + 0) * P + (k + 2))];                      \
+    v2h b03 = *(v2h *)&B[2 * ((j + 0) * P + (k + 3))];                      \
+    v2h b10 = *(v2h *)&B[2 * ((j + 1) * P + (k + 0))];                      \
+    v2h b11 = *(v2h *)&B[2 * ((j + 1) * P + (k + 1))];                      \
+    v2h b12 = *(v2h *)&B[2 * ((j + 1) * P + (k + 2))];                      \
+    v2h b13 = *(v2h *)&B[2 * ((j + 1) * P + (k + 3))];                      \
     asm volatile("pv.shuffle2.h %[a00s], %[a00], %[mask];"                  \
                  "pv.shuffle2.h %[a10s], %[a10], %[mask];"                  \
                  "pv.shuffle2.h %[a01s], %[a01], %[mask];"                  \
                  "pv.shuffle2.h %[a11s], %[a11], %[mask];"                  \
-                 : [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \
-                   [a11s] "+r"(a11s)                                        \
+                 : [a00s] "=r"(a00s), [a01s] "=r"(a01s), [a10s] "=r"(a10s), \
+                   [a11s] "=r"(a11s)                                        \
                  : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10),          \
                    [a11] "r"(a11), [mask] "r"(0x00020003)                   \
                  :);                                                        \
@@ -156,10 +182,10 @@
         "vfdotpex.s.h %[sum12_imag], %[a11s], %[b12];"                      \
         "vfdotpex.s.h %[sum03_imag], %[a01s], %[b13];"                      \
         "vfdotpex.s.h %[sum13_imag], %[a11s], %[b13];"                      \
-        : [sum00_imag] "=r"(sum00_imag), [sum01_imag] "=r"(sum01_imag),     \
-          [sum02_imag] "=r"(sum02_imag), [sum03_imag] "=r"(sum03_imag),     \
-          [sum10_imag] "=r"(sum10_imag), [sum11_imag] "=r"(sum11_imag),     \
-          [sum12_imag] "=r"(sum12_imag), [sum13_imag] "=r"(sum13_imag)      \
+        : [sum00_imag] "+&r"(sum00_imag), [sum01_imag] "+&r"(sum01_imag),   \
+          [sum02_imag] "+&r"(sum02_imag), [sum03_imag] "+&r"(sum03_imag),   \
+          [sum10_imag] "+&r"(sum10_imag), [sum11_imag] "+&r"(sum11_imag),   \
+          [sum12_imag] "+&r"(sum12_imag), [sum13_imag] "+&r"(sum13_imag)    \
         : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s),             \
           [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \
           [b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12),   \
          [b13] "r"(b13)                                                     \
@@ -169,8 +195,8 @@
                  "xor %[a10s], %[a10], %[mask];"                            \
                  "xor %[a01s], %[a01], %[mask];"                            \
                  "xor %[a11s], %[a11], %[mask];"                            \
-                 : [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \
-                   [a11s] "+r"(a11s)                                        \
+                 : [a00s] "=r"(a00s), [a01s] "=r"(a01s), [a10s] "=r"(a10s), \
+                   [a11s] "=r"(a11s)                                        \
                  : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10),          \
                    [a11] "r"(a11), [mask] "r"(0x80000000)                   \
                  :);                                                        \
@@ -191,18 +217,16 @@
         "vfdotpex.s.h %[sum12_real], %[a11s], %[b12];"                      \
         "vfdotpex.s.h %[sum03_real], %[a01s], %[b13];"                      \
         "vfdotpex.s.h %[sum13_real], %[a11s], %[b13];"                      \
-        : [sum00_real] "=r"(sum00_real),
[sum01_real] "=r"(sum01_real), \ - [sum02_real] "=r"(sum02_real), [sum03_real] "=r"(sum03_real), \ - [sum10_real] "=r"(sum10_real), [sum11_real] "=r"(sum11_real), \ - [sum12_real] "=r"(sum12_real), [sum13_real] "=r"(sum13_real) \ + : [sum00_real] "+&r"(sum00_real), [sum01_real] "+&r"(sum01_real), \ + [sum02_real] "+&r"(sum02_real), [sum03_real] "+&r"(sum03_real), \ + [sum10_real] "+&r"(sum10_real), [sum11_real] "+&r"(sum11_real), \ + [sum12_real] "+&r"(sum12_real), [sum13_real] "+&r"(sum13_real) \ : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \ [b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \ [b13] "r"(b13) \ :); \ } \ - v2h res00, res01, res02, res03; \ - v2h res10, res11, res12, res13; \ asm volatile( \ "vfcpka.h.s %[res00], %[sum00_real], %[sum00_imag];" \ "vfcpka.h.s %[res01], %[sum01_real], %[sum01_imag];" \ @@ -212,9 +236,9 @@ "vfcpka.h.s %[res11], %[sum11_real], %[sum11_imag];" \ "vfcpka.h.s %[res12], %[sum12_real], %[sum12_imag];" \ "vfcpka.h.s %[res13], %[sum13_real], %[sum13_imag];" \ - : [res00] "+r"(res00), [res01] "+r"(res01), [res02] "+r"(res02), \ - [res03] "+r"(res03), [res10] "+r"(res10), [res11] "+r"(res11), \ - [res12] "+r"(res12), [res13] "+r"(res13) \ + : [res00] "=r"(res00), [res01] "=r"(res01), [res02] "=r"(res02), \ + [res03] "=r"(res03), [res10] "=r"(res10), [res11] "=r"(res11), \ + [res12] "=r"(res12), [res13] "=r"(res13) \ : [sum00_imag] "r"(sum00_imag), [sum01_imag] "r"(sum01_imag), \ [sum02_imag] "r"(sum02_imag), [sum03_imag] "r"(sum03_imag), \ [sum10_imag] "r"(sum10_imag), [sum11_imag] "r"(sum11_imag), \ @@ -233,32 +257,24 @@ (*(v2h *)&C[2 * ((i + 1) * P + k + 2)]) = res12; \ (*(v2h *)&C[2 * ((i + 1) * P + k + 3)]) = res13; -#define CMATMUL_2x4_LOOP_FOLDED \ - float sum00_real = 0.0f; \ - float sum01_real = 0.0f; \ - float sum02_real = 0.0f; \ - float sum03_real = 0.0f; \ - float sum10_real = 0.0f; \ - float sum11_real = 0.0f; \ - float sum12_real = 0.0f; \ - float sum13_real = 0.0f; \ - float sum00_imag = 0.0f; \ - float sum01_imag = 0.0f; \ - float sum02_imag = 0.0f; \ - float sum03_imag = 0.0f; \ - float sum10_imag = 0.0f; \ - float sum11_imag = 0.0f; \ - float sum12_imag = 0.0f; \ - float sum13_imag = 0.0f; \ +/**************************************************************************/ +/**************************************************************************/ +// COMPLEX DOTP INSTRUCTIONS + +#define CMATMUL_CDOTP_2x4_LOOP \ + v2h sum00 = (v2h)0.0f; \ + v2h sum01 = (v2h)0.0f; \ + v2h sum02 = (v2h)0.0f; \ + v2h sum03 = (v2h)0.0f; \ + v2h sum10 = (v2h)0.0f; \ + v2h sum11 = (v2h)0.0f; \ + v2h sum12 = (v2h)0.0f; \ + v2h sum13 = (v2h)0.0f; \ for (j = 0; j < N; j += 2) { \ - v2h a00 = \ - (*(v2h *)&A_folded[2 * ((i + 0) * NUM_BANKS + shift * N + (j + 0))]); \ - v2h a01 = \ - (*(v2h *)&A_folded[2 * ((i + 0) * NUM_BANKS + shift * N + (j + 1))]); \ - v2h a10 = \ - (*(v2h *)&A_folded[2 * ((i + 1) * NUM_BANKS + shift * N + (j + 0))]); \ - v2h a11 = \ - (*(v2h *)&A_folded[2 * ((i + 1) * NUM_BANKS + shift * N + (j + 1))]); \ + v2h a00 = (*(v2h *)&A_fetch[2 * ((i + 0) * N + (j + 0))]); \ + v2h a01 = (*(v2h *)&A_fetch[2 * ((i + 0) * N + (j + 1))]); \ + v2h a10 = (*(v2h *)&A_fetch[2 * ((i + 1) * N + (j + 0))]); \ + v2h a11 = (*(v2h *)&A_fetch[2 * ((i + 1) * N + (j + 1))]); \ v2h b00 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 0))]); \ v2h b01 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 1))]); \ v2h b02 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 2))]); \ @@ -267,157 +283,39 @@ v2h b11 = (*(v2h 
*)&B[2 * ((j + 1) * P + (k + 1))]); \ v2h b12 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 2))]); \ v2h b13 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 3))]); \ - v2h a00s, a01s, a10s, a11s; \ - asm volatile("pv.shuffle2.h %[a00s], %[a00], %[mask];" \ - "pv.shuffle2.h %[a01s], %[a01], %[mask];" \ - "pv.shuffle2.h %[a10s], %[a10], %[mask];" \ - "pv.shuffle2.h %[a11s], %[a11], %[mask];" \ - : [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \ - [a11s] "+r"(a11s) \ - : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), \ - [a11] "r"(a11), [mask] "r"(0x00020003) \ - :); \ asm volatile( \ - "vfdotpex.s.h %[sum00_imag], %[a00s], %[b00];" \ - "vfdotpex.s.h %[sum10_imag], %[a10s], %[b00];" \ - "vfdotpex.s.h %[sum01_imag], %[a00s], %[b01];" \ - "vfdotpex.s.h %[sum11_imag], %[a10s], %[b01];" \ - "vfdotpex.s.h %[sum02_imag], %[a00s], %[b02];" \ - "vfdotpex.s.h %[sum12_imag], %[a10s], %[b02];" \ - "vfdotpex.s.h %[sum03_imag], %[a00s], %[b03];" \ - "vfdotpex.s.h %[sum13_imag], %[a10s], %[b03];" \ - "vfdotpex.s.h %[sum00_imag], %[a01s], %[b10];" \ - "vfdotpex.s.h %[sum10_imag], %[a11s], %[b10];" \ - "vfdotpex.s.h %[sum01_imag], %[a01s], %[b11];" \ - "vfdotpex.s.h %[sum11_imag], %[a11s], %[b11];" \ - "vfdotpex.s.h %[sum02_imag], %[a01s], %[b12];" \ - "vfdotpex.s.h %[sum12_imag], %[a11s], %[b12];" \ - "vfdotpex.s.h %[sum03_imag], %[a01s], %[b13];" \ - "vfdotpex.s.h %[sum13_imag], %[a11s], %[b13];" \ - : [sum00_imag] "=r"(sum00_imag), [sum01_imag] "=r"(sum01_imag), \ - [sum02_imag] "=r"(sum02_imag), [sum03_imag] "=r"(sum03_imag), \ - [sum10_imag] "=r"(sum10_imag), [sum11_imag] "=r"(sum11_imag), \ - [sum12_imag] "=r"(sum12_imag), [sum13_imag] "=r"(sum13_imag) \ - : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ - [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \ - [b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \ - [b13] "r"(b13) \ - :); \ - asm volatile("xor %[a00s], %[a00], %[mask];" \ - "xor %[a01s], %[a01], %[mask];" \ - "xor %[a10s], %[a10], %[mask];" \ - "xor %[a11s], %[a11], %[mask];" \ - : [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \ - [a11s] "+r"(a11s) \ - : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), \ - [a11] "r"(a11), [mask] "r"(0x80000000) \ - :); \ - asm volatile( \ - "vfdotpex.s.h %[sum00_real], %[a00s], %[b00];" \ - "vfdotpex.s.h %[sum10_real], %[a10s], %[b00];" \ - "vfdotpex.s.h %[sum01_real], %[a00s], %[b01];" \ - "vfdotpex.s.h %[sum11_real], %[a10s], %[b01];" \ - "vfdotpex.s.h %[sum02_real], %[a00s], %[b02];" \ - "vfdotpex.s.h %[sum12_real], %[a10s], %[b02];" \ - "vfdotpex.s.h %[sum03_real], %[a00s], %[b03];" \ - "vfdotpex.s.h %[sum13_real], %[a10s], %[b03];" \ - "vfdotpex.s.h %[sum00_real], %[a01s], %[b10];" \ - "vfdotpex.s.h %[sum10_real], %[a11s], %[b10];" \ - "vfdotpex.s.h %[sum01_real], %[a01s], %[b11];" \ - "vfdotpex.s.h %[sum11_real], %[a11s], %[b11];" \ - "vfdotpex.s.h %[sum02_real], %[a01s], %[b12];" \ - "vfdotpex.s.h %[sum12_real], %[a11s], %[b12];" \ - "vfdotpex.s.h %[sum03_real], %[a01s], %[b13];" \ - "vfdotpex.s.h %[sum13_real], %[a11s], %[b13];" \ - : [sum00_real] "=r"(sum00_real), [sum01_real] "=r"(sum01_real), \ - [sum02_real] "=r"(sum02_real), [sum03_real] "=r"(sum03_real), \ - [sum10_real] "=r"(sum10_real), [sum11_real] "=r"(sum11_real), \ - [sum12_real] "=r"(sum12_real), [sum13_real] "=r"(sum13_real) \ - : [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \ - [a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \ - [b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \ - [b13] "r"(b13) \ + 
"fcdotpex.s.h %[sum00], %[a00], %[b00];" \ + "fcdotpex.s.h %[sum10], %[a10], %[b00];" \ + "fcdotpex.s.h %[sum01], %[a00], %[b01];" \ + "fcdotpex.s.h %[sum11], %[a10], %[b01];" \ + "fcdotpex.s.h %[sum02], %[a00], %[b02];" \ + "fcdotpex.s.h %[sum12], %[a10], %[b02];" \ + "fcdotpex.s.h %[sum03], %[a00], %[b03];" \ + "fcdotpex.s.h %[sum13], %[a10], %[b03];" \ + "fcdotpex.s.h %[sum00], %[a01], %[b10];" \ + "fcdotpex.s.h %[sum10], %[a11], %[b10];" \ + "fcdotpex.s.h %[sum01], %[a01], %[b11];" \ + "fcdotpex.s.h %[sum11], %[a11], %[b11];" \ + "fcdotpex.s.h %[sum02], %[a01], %[b12];" \ + "fcdotpex.s.h %[sum12], %[a11], %[b12];" \ + "fcdotpex.s.h %[sum03], %[a01], %[b13];" \ + "fcdotpex.s.h %[sum13], %[a11], %[b13];" \ + : [sum00] "+&r"(sum00), [sum01] "+&r"(sum01), [sum02] "+&r"(sum02), \ + [sum03] "+&r"(sum03), [sum10] "+&r"(sum10), [sum11] "+&r"(sum11), \ + [sum12] "+&r"(sum12), [sum13] "+&r"(sum13) \ + : [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), [a11] "r"(a11), \ + [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), [b03] "r"(b03), \ + [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), [b13] "r"(b13) \ :); \ } \ - v2h res00, res01, res02, res03; \ - v2h res10, res11, res12, res13; \ - asm volatile( \ - "vfcpka.h.s %[res00], %[sum00_real], %[sum00_imag];" \ - "vfcpka.h.s %[res01], %[sum01_real], %[sum01_imag];" \ - "vfcpka.h.s %[res02], %[sum02_real], %[sum02_imag];" \ - "vfcpka.h.s %[res03], %[sum03_real], %[sum03_imag];" \ - "vfcpka.h.s %[res10], %[sum10_real], %[sum10_imag];" \ - "vfcpka.h.s %[res11], %[sum11_real], %[sum11_imag];" \ - "vfcpka.h.s %[res12], %[sum12_real], %[sum12_imag];" \ - "vfcpka.h.s %[res13], %[sum13_real], %[sum13_imag];" \ - : [res00] "+r"(res00), [res01] "+r"(res01), [res02] "+r"(res02), \ - [res03] "+r"(res03), [res10] "+r"(res10), [res11] "+r"(res11), \ - [res12] "+r"(res12), [res13] "+r"(res13) \ - : [sum00_imag] "r"(sum00_imag), [sum01_imag] "r"(sum01_imag), \ - [sum02_imag] "r"(sum02_imag), [sum03_imag] "r"(sum03_imag), \ - [sum10_imag] "r"(sum10_imag), [sum11_imag] "r"(sum11_imag), \ - [sum12_imag] "r"(sum12_imag), [sum13_imag] "r"(sum13_imag), \ - [sum00_real] "r"(sum00_real), [sum01_real] "r"(sum01_real), \ - [sum02_real] "r"(sum02_real), [sum03_real] "r"(sum03_real), \ - [sum10_real] "r"(sum10_real), [sum11_real] "r"(sum11_real), \ - [sum12_real] "r"(sum12_real), [sum13_real] "r"(sum13_real) \ - :); \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 0)]) = res00; \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 1)]) = res01; \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 2)]) = res02; \ - (*(v2h *)&C[2 * ((i + 0) * P + k + 3)]) = res03; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 0)]) = res10; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 1)]) = res11; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 2)]) = res12; \ - (*(v2h *)&C[2 * ((i + 1) * P + k + 3)]) = res13; - -void cmatmul_f16s(__fp16 const *__restrict__ A, __fp16 const *__restrict__ B, - __fp16 *__restrict__ C, uint32_t M, uint32_t N, uint32_t P) { - uint32_t i = 0; // loop counter for M - uint32_t j = 0; // loop counter for N - uint32_t k = 0; // loop counter for P - for (k = 0; k < P; k++) { - for (i = 0; i < M; i++) { - float sum_real = 0.0f; - float sum_imag = 0.0f; - for (j = 0; j < N; j++) { - v2h a = (*(v2h *)&A[2 * (i * N + j)]); - v2h b = (*(v2h *)&B[2 * (j * P + k)]); - v2h as; - asm volatile("pv.shuffle2.h %[as], %[a], %[mask_shuffle];" - "vfdotpex.s.h %[sum_imag], %[as], %[b];" - "xor %[as], %[a], %[mask_sign];" - "vfdotpex.s.h %[sum_real], %[as], %[b];" - : [sum_real] "+&r"(sum_real), [sum_imag] "+&r"(sum_imag), - [as] "+&r"(as) - : [a] 
"r"(a), [b] "r"(b), [mask_shuffle] "r"(0x00020003), - [mask_sign] "r"(0x80000000) - :); - } - v2h res; - asm volatile("vfcpka.h.s %[res], %[sum_real], %[sum_imag];" - : [res] "+r"(res) - : [sum_real] "r"(sum_real), [sum_imag] "r"(sum_imag) - :); - (*(v2h *)&C[2 * (i * P + k)]) = res; - } - } - return; -} - -void cmatmul_2x4_f16s(__fp16 const *__restrict__ A, - __fp16 const *__restrict__ B, __fp16 *__restrict__ C, - uint32_t M, uint32_t N, uint32_t P) { - uint32_t i = 0; // loop counter for M - uint32_t j = 0; // loop counter for N - uint32_t k = 0; // loop counter for P - for (k = 0; k < P; k += 4) { - for (i = 0; i < M; i += 2) { - CMATMUL_2x4_LOOP; - } - } - return; -} + (*(v2h *)&C[(i + 0) * P + k + 0]) = sum00; \ + (*(v2h *)&C[(i + 0) * P + k + 1]) = sum01; \ + (*(v2h *)&C[(i + 0) * P + k + 2]) = sum02; \ + (*(v2h *)&C[(i + 0) * P + k + 3]) = sum03; \ + (*(v2h *)&C[(i + 1) * P + k + 0]) = sum10; \ + (*(v2h *)&C[(i + 1) * P + k + 1]) = sum11; \ + (*(v2h *)&C[(i + 1) * P + k + 2]) = sum12; \ + (*(v2h *)&C[(i + 1) * P + k + 3]) = sum13; void cmatmul_2x2_f16p(__fp16 const *__restrict__ A, __fp16 const *__restrict__ B, __fp16 *__restrict__ C, @@ -427,33 +325,34 @@ void cmatmul_2x2_f16p(__fp16 const *__restrict__ A, uint32_t i = 0; // loop counter for M uint32_t j = 0; // loop counter for N uint32_t k = 0; // loop counter for P - uint32_t shift_id = 2 * (core_id % NUM_CORES_PER_TILE); for (k = core_id * 2; k < P; k += 2 * numThreads) { - // dump_prova(k); - for (i = shift_id; i < M; i += 2) { - CMATMUL_2x2_LOOP; - } - for (i = 0; i < shift_id; i += 2) { + for (i = 0; i < M; i += 2) { CMATMUL_2x2_LOOP; } } + mempool_log_partial_barrier(2, core_id, numThreads); return; } -void cmatmul_2x4_f16p(__fp16 const *__restrict__ A, - __fp16 const *__restrict__ B, __fp16 *__restrict__ C, - uint32_t M, uint32_t N, uint32_t P, uint32_t core_id, - uint32_t numThreads) { +void cmatmul_2x4_f16p(__fp16 *__restrict__ A, __fp16 const *__restrict__ B, + __fp16 *__restrict__ C, uint32_t M, uint32_t N, + uint32_t P, uint32_t core_id, uint32_t numThreads) { uint32_t i = 0; // loop counter for M uint32_t j = 0; // loop counter for N uint32_t k = 0; // loop counter for P - // uint32_t shift_id = 2 * (core_id % NUM_CORES_PER_TILE); + uint32_t shift_id = 2 * (core_id % NUM_CORES_PER_TILE); + __fp16 *A_fetch = A; + for (k = core_id * 4; k < P; k += 4 * numThreads) { - for (i = 0; i < M; i += 2) { - CMATMUL_2x4_LOOP; + for (i = shift_id; i < M; i += 2) { + CMATMUL_CDOTP_2x4_LOOP; + } + for (i = 0; i < shift_id; i += 2) { + CMATMUL_CDOTP_2x4_LOOP; } } + mempool_log_partial_barrier(2, core_id, numThreads); return; } @@ -467,52 +366,60 @@ void cmatmul_2x4_folded_f16p(__fp16 const *__restrict__ A, uint32_t i = 0; // loop counter for M uint32_t j = 0; // loop counter for N uint32_t k = 0; // loop counter for P - uint32_t NUM_BANKS = NUM_CORES * 4; - // uint32_t core_id_shuffling = core_id / (N / 4); - // uint32_t numThreads_shuffling = numThreads / (N / 4); - // if ((core_id % (N / 4)) == 0) { - // // Loop over the rows with cores from different cluster locations - // for (i = core_id_shuffling; i < M; i += numThreads_shuffling) { - // for (j = 0; j < N; j += 4) { - // // Copy multiple A matrices in memory - // v2h a0 = (*(v2h *)&A[2 * (i * N + j)]); - // v2h a1 = (*(v2h *)&A[2 * (i * N + j + 1)]); - // v2h a2 = (*(v2h *)&A[2 * (i * N + j + 2)]); - // v2h a3 = (*(v2h *)&A[2 * (i * N + j + 3)]); - // for (k = 0; k < NUM_BANKS / N; k++) { - // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + k * N + j)]) = a0; - // (*(v2h *)&A_folded[2 
* (i * NUM_BANKS + k * N + j + 1)]) = a1; - // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + k * N + j + 2)]) = a2; - // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + k * N + j + 3)]) = a3; - // } - // } + // // Copy multiple A matrices in memory + // uint32_t num_copy = NUM_BANKS / N; + // for (k = core_id * 4; k < N * M; k += 4 * numThreads) { + // v2h a0 = (*(v2h *)&A[2 * k]); + // v2h a1 = (*(v2h *)&A[2 * (k + 1)]); + // v2h a2 = (*(v2h *)&A[2 * (k + 2)]); + // v2h a3 = (*(v2h *)&A[2 * (k + 3)]); + // i = k / N; // row_index + // j = k % N; // col_index + // for (uint32_t idx_copy = 0; idx_copy < num_copy; idx_copy++) { + // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + idx_copy * N)]) = a0; + // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + 1 + idx_copy * N)]) = a1; + // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + 2 + idx_copy * N)]) = a2; + // (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + 3 + idx_copy * N)]) = a3; // } // } + // v2h* A_fetch = A_folded + N * (core_id * 4) / N; + // mempool_barrier(NUM_CORES); + // for (k = core_id * 4; k < P; k += 4 * numThreads) { + // for (i = 0; i < M; i += 2) { + // CMATMUL_2x4_LOOP; + // } + // } + // mempool_barrier(NUM_CORES); // Copy multiple A matrices in memory - uint32_t num_copy = NUM_BANKS / N; + uint32_t num_copy = NUM_BANKS / (N * M); for (k = core_id * 4; k < N * M; k += 4 * numThreads) { - v2h a0 = (*(v2h *)&A[2 * k]); - v2h a1 = (*(v2h *)&A[2 * (k + 1)]); - v2h a2 = (*(v2h *)&A[2 * (k + 2)]); - v2h a3 = (*(v2h *)&A[2 * (k + 3)]); + v2h a0 = *(v2h *)&A[2 * k]; + v2h a1 = *(v2h *)&A[2 * (k + 1)]; + v2h a2 = *(v2h *)&A[2 * (k + 2)]; + v2h a3 = *(v2h *)&A[2 * (k + 3)]; i = k / N; // row_index j = k % N; // col_index for (uint32_t idx_copy = 0; idx_copy < num_copy; idx_copy++) { - (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + idx_copy * N)]) = a0; - (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + 1 + idx_copy * N)]) = a1; - (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + 2 + idx_copy * N)]) = a2; - (*(v2h *)&A_folded[2 * (i * NUM_BANKS + j + 3 + idx_copy * N)]) = a3; + (*(v2h *)&A_folded[2 * (idx_copy * N * M + i * N + j)]) = a0; + (*(v2h *)&A_folded[2 * (idx_copy * N * M + i * N + j + 1)]) = a1; + (*(v2h *)&A_folded[2 * (idx_copy * N * M + i * N + j + 2)]) = a2; + (*(v2h *)&A_folded[2 * (idx_copy * N * M + i * N + j + 3)]) = a3; } } - mempool_barrier(numThreads); - - uint32_t shift = core_id / N; + __fp16 *A_fetch = + A_folded + 2 * (N * M) * (core_id / ((N * M) / BANKING_FACTOR)); + mempool_barrier(NUM_CORES); + mempool_log_partial_barrier(2, core_id, numThreads); + // Compute for (k = core_id * 4; k < P; k += 4 * numThreads) { for (i = 0; i < M; i += 2) { - CMATMUL_2x4_LOOP_FOLDED; + CMATMUL_2x4_LOOP; + // CMATMUL_CDOTP_2x4_LOOP; } } + mempool_log_partial_barrier(2, core_id, numThreads); + return; }
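
Note on the data layout: every kernel in this patch computes C = A x B on
complex __fp16 data stored as interleaved {real, imaginary} pairs, so element
(i, j) of an M x N matrix X lives at X[2 * (i * N + j)] (real part) and
X[2 * (i * N + j) + 1] (imaginary part); this is where the factor of 2 in all
scalar indices comes from. A minimal scalar golden model of the computation is
given below; it is a sketch for cross-checking against mempool_check_f16, and
cmatmul_golden with its float accumulators is illustrative, not part of the
patch.

#include <stdint.h>

// Golden model: C = A * B on complex matrices stored as interleaved
// {real, imag} __fp16 pairs. Accumulates in float, like the kernels above.
void cmatmul_golden(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
                    uint32_t N, uint32_t P) {
  for (uint32_t i = 0; i < M; i++) {
    for (uint32_t k = 0; k < P; k++) {
      float sum_re = 0.0f;
      float sum_im = 0.0f;
      for (uint32_t j = 0; j < N; j++) {
        float ar = (float)A[2 * (i * N + j)];     // Re(A[i][j])
        float ai = (float)A[2 * (i * N + j) + 1]; // Im(A[i][j])
        float br = (float)B[2 * (j * P + k)];     // Re(B[j][k])
        float bi = (float)B[2 * (j * P + k) + 1]; // Im(B[j][k])
        sum_re += ar * br - ai * bi; // real part of a*b
        sum_im += ar * bi + ai * br; // imaginary part of a*b
      }
      C[2 * (i * P + k)] = (__fp16)sum_re;
      C[2 * (i * P + k) + 1] = (__fp16)sum_im;
    }
  }
}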
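Note on the complex arithmetic: with a = {a_re, a_im} and b = {b_re, b_im}
packed into one v2h, (a_re + i*a_im)(b_re + i*b_im) =
(a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re). The vfdotpex.s.h-based
macros obtain each part with one operand permutation: the xor with 0x80000000
flips the sign bit of the upper half (the imaginary lane), so the dot product
yields the real part, and pv.shuffle2.h with mask 0x00020003 swaps the two
halves, so the same dot product yields the imaginary part. The fcdotpex.s.h
instruction introduced in CMATMUL_CDOTP_2x4_LOOP fuses the permutations and
both dot products into a single complex multiply-accumulate on v2h operands,
which is why that variant needs neither the shuffle/xor preprocessing nor the
final vfcpka.h.s repacking of float accumulators. In scalar form (a sketch,
assuming the real part in the low half and the imaginary part in the high
half, consistent with the masks above):

typedef struct {
  float re, im;
} cplxf;

// What one "xor 0x80000000" + vfdotpex.s.h pair computes: the real part,
// i.e. {a_re, -a_im} . {b_re, b_im}.
static inline float cdotp_re(cplxf a, cplxf b) {
  return a.re * b.re - a.im * b.im;
}
// What one pv.shuffle2.h + vfdotpex.s.h pair computes: the imaginary part,
// i.e. {a_im, a_re} . {b_re, b_im}.
static inline float cdotp_im(cplxf a, cplxf b) {
  return a.im * b.re + a.re * b.im;
}
// fcdotpex.s.h accumulates both halves at once: acc += a * b.
static inline cplxf fcdotp(cplxf acc, cplxf a, cplxf b) {
  acc.re += cdotp_re(a, b);
  acc.im += cdotp_im(a, b);
  return acc;
}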
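Note on the folded variant: cmatmul_2x4_folded_f16p replicates A so that the
loads spread over all NUM_BANKS = NUM_CORES * BANKING_FACTOR memory banks
instead of contending on the N * M locations of a single copy. The copy loop
lays out num_copy = NUM_BANKS / (N * M) back-to-back copies, with copy c
holding element (i, j) at A_folded[2 * (c * N * M + i * N + j)], and each core
picks its copy through core_id / ((N * M) / BANKING_FACTOR), i.e. one copy per
group of (N * M) / BANKING_FACTOR consecutive core IDs. As a worked example
(hypothetical sizes, not taken from this patch): with NUM_CORES = 256,
BANKING_FACTOR = 4, and M = N = 16, NUM_BANKS = 1024 and N * M = 256, so A is
replicated num_copy = 4 times and each copy serves 256 / 4 = 64 cores.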