Skip to content

Commit

Permalink
[software] Add complex instructions to fp16 cmatmul
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertuletti committed Jan 8, 2024
1 parent ede7d50 commit 1253aad
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 508 deletions.
43 changes: 24 additions & 19 deletions software/apps/cmatmul_f16/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

// Author: Samuel Riedel, ETH Zurich
// Author: Marco Bertuletti, ETH Zurich

#include <stdint.h>
#include <string.h>
Expand All @@ -15,7 +15,9 @@
#include "data/data_cmatmul_f16.h"
#include "kernel/mempool_checks.h"
#include "kernel/mempool_cmatmul_f16.h"
#define PARALLEL_2x2
#define PARALLEL
#define LOOP_2x4
// #define TEST

__fp16 matrix_a[2 * dim_M * dim_N]
__attribute__((aligned(BANKING_FACTOR * NUM_CORES * sizeof(int32_t)),
Expand All @@ -38,8 +40,8 @@ int main() {

// Initialize Matrices
if (core_id == 0) {
dma_memcpy_blocking(matrix_a, A, dim_M * dim_N * sizeof(int32_t));
dma_memcpy_blocking(matrix_b, B, dim_N * dim_P * sizeof(int32_t));
dma_memcpy_blocking(matrix_a, A, 2 * dim_M * dim_N * sizeof(int16_t));
dma_memcpy_blocking(matrix_b, B, 2 * dim_N * dim_P * sizeof(int16_t));
}
// Wait at barrier until everyone is ready
mempool_barrier(num_cores);
Expand All @@ -48,40 +50,43 @@ int main() {
// Execute function to test.
if (core_id == 0) {
mempool_start_benchmark();
cmatmul_2x4_f16s(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P);
cmatmul_2x2_f16s(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P);
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
#endif

#if defined(PARALLEL_2x2)
#if defined(PARALLEL)
// Execute function to test.
uint32_t nPE = core_id < (dim_P / 2) ? num_cores : (dim_P / 2);

#if defined(LOOP_2x2)
uint32_t nPE = num_cores < (dim_P / 2) ? num_cores : (dim_P / 2); // 2x2
if (core_id < nPE) {
mempool_start_benchmark();
cmatmul_2x2_f16p(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P, core_id,
nPE);
mempool_log_partial_barrier(2, core_id, nPE);
cmatmul_2x2_f16p((v2h *)matrix_a, (v2h *)matrix_b, (v2h *)matrix_c, dim_M,
dim_N, dim_P, core_id, nPE);
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
#endif

#if defined(PARALLEL_2x4)
// Execute function to test.
uint32_t nPE = core_id < (dim_P / 4) ? num_cores : (dim_P / 4);
#if defined(LOOP_2x4)
uint32_t nPE = num_cores < (dim_P / 4) ? num_cores : (dim_P / 4); // 2x4
if (core_id < nPE) {
mempool_start_benchmark();
cmatmul_2x4_f16p(matrix_a, matrix_b, matrix_c, dim_M, dim_N, dim_P, core_id,
nPE);
mempool_log_partial_barrier(2, core_id, nPE);
// cmatmul_2x4_f16p(
// (v2h*) matrix_a,
// (v2h*) matrix_b,
// (v2h*) matrix_c,
// dim_M, dim_N, dim_P, core_id, nPE);
cmatmul_2x4_folded_f16p(matrix_a, matrix_b, matrix_a_folded, matrix_c,
dim_M, dim_N, dim_P, core_id, nPE);
mempool_stop_benchmark();
}
#endif
mempool_barrier(num_cores);
#endif

#if defined(TEST)
mempool_check_f32(matrix_c, C, 2 * dim_M * dim_P, 0.01f, 0);
mempool_check_f16(matrix_c, C, 2 * dim_M * dim_P, 0.1f, 0);
mempool_barrier(num_cores);
#endif

Expand Down
4 changes: 2 additions & 2 deletions software/apps/ofdm/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ int main() {

/* BEAMFORMING */
mempool_start_benchmark();
cmatmul_2x4_folded_f16p(l1_pBF_Coef_folded, l1_pFFT_Src, l1_pFFT_Dst, N_BEAMS,
N_RX, N_SC, core_id, num_cores);
cmatmul_2x4_folded_f16p(l1_pBF_Coef_folded, l1_pBF_Coef_folded, l1_pFFT_Src,
l1_pFFT_Dst, N_BEAMS, N_RX, N_SC, core_id, num_cores);
mempool_stop_benchmark();
dump_prova(2);

Expand Down
2 changes: 1 addition & 1 deletion software/runtime/data/data_cmatmul_f16.h.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
i = 0
out += '\n'
for a in array:
out += '(__fp16){:.5f}, '.format(a)
out += '(__fp16){:.4f}, '.format(a)
i += 1
if i % 8 == 0:
out += '\n'
Expand Down
217 changes: 0 additions & 217 deletions software/runtime/kernel/cmatmul_f16.h
Original file line number Diff line number Diff line change
@@ -1,217 +0,0 @@
// Copyright 2021 ETH Zurich and University of Bologna.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

// Author: Marco Bertuletti, ETH Zurich

/* This library implements the complex matrix multiplication in multiple
* different ways. The functions all follow the following format:
*
* A is an M x N matrix, B is a N x P matrix, and C is a M x P matrix
* C = AB
*/

#include "xpulp/builtins_v2.h"

#define CMATMUL_2x4_LOOP \
float sum00_real = 0.0f; \
float sum01_real = 0.0f; \
float sum02_real = 0.0f; \
float sum03_real = 0.0f; \
float sum10_real = 0.0f; \
float sum11_real = 0.0f; \
float sum12_real = 0.0f; \
float sum13_real = 0.0f; \
float sum00_imag = 0.0f; \
float sum01_imag = 0.0f; \
float sum02_imag = 0.0f; \
float sum03_imag = 0.0f; \
float sum10_imag = 0.0f; \
float sum11_imag = 0.0f; \
float sum12_imag = 0.0f; \
float sum13_imag = 0.0f; \
for (j = 0; j < N; j += 2) { \
v2h a00 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 0))]); \
v2h a01 = (*(v2h *)&A[2 * ((i + 0) * N + (j + 1))]); \
v2h a10 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 0))]); \
v2h a11 = (*(v2h *)&A[2 * ((i + 1) * N + (j + 1))]); \
v2h b00 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 0))]); \
v2h b01 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 1))]); \
v2h b02 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 2))]); \
v2h b03 = (*(v2h *)&B[2 * ((j + 0) * P + (k + 3))]); \
v2h b10 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 0))]); \
v2h b11 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 1))]); \
v2h b12 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 2))]); \
v2h b13 = (*(v2h *)&B[2 * ((j + 1) * P + (k + 3))]); \
v2h a00s, a01s, a10s, a11s; \
asm volatile("pv.shuffle2.h %[a00s], %[a00], %[mask];" \
"pv.shuffle2.h %[a01s], %[a01], %[mask];" \
"pv.shuffle2.h %[a10s], %[a10], %[mask];" \
"pv.shuffle2.h %[a11s], %[a11], %[mask];" \
: [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \
[a11s] "+r"(a11s) \
: [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), \
[a11] "r"(a11), [mask] "r"(0x00020003) \
:); \
asm volatile( \
"vfdotpex.s.h %[sum00_imag], %[a00s], %[b00];" \
"vfdotpex.s.h %[sum10_imag], %[a10s], %[b00];" \
"vfdotpex.s.h %[sum01_imag], %[a00s], %[b01];" \
"vfdotpex.s.h %[sum11_imag], %[a10s], %[b01];" \
"vfdotpex.s.h %[sum02_imag], %[a00s], %[b02];" \
"vfdotpex.s.h %[sum12_imag], %[a10s], %[b02];" \
"vfdotpex.s.h %[sum03_imag], %[a00s], %[b03];" \
"vfdotpex.s.h %[sum13_imag], %[a10s], %[b03];" \
"vfdotpex.s.h %[sum00_imag], %[a01s], %[b10];" \
"vfdotpex.s.h %[sum10_imag], %[a11s], %[b10];" \
"vfdotpex.s.h %[sum01_imag], %[a01s], %[b11];" \
"vfdotpex.s.h %[sum11_imag], %[a11s], %[b11];" \
"vfdotpex.s.h %[sum02_imag], %[a01s], %[b12];" \
"vfdotpex.s.h %[sum12_imag], %[a11s], %[b12];" \
"vfdotpex.s.h %[sum03_imag], %[a01s], %[b13];" \
"vfdotpex.s.h %[sum13_imag], %[a11s], %[b13];" \
: [sum00_imag] "+&r"(sum00_imag), [sum01_imag] "+&r"(sum01_imag), \
[sum02_imag] "+&r"(sum02_imag), [sum03_imag] "+&r"(sum03_imag), \
[sum10_imag] "+&r"(sum10_imag), [sum11_imag] "+&r"(sum11_imag), \
[sum12_imag] "+&r"(sum12_imag), [sum13_imag] "+&r"(sum13_imag) \
: [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \
[a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \
[b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \
[b13] "r"(b13) \
:); \
asm volatile("xor %[a00s], %[a00], %[mask];" \
"xor %[a01s], %[a01], %[mask];" \
"xor %[a10s], %[a10], %[mask];" \
"xor %[a11s], %[a11], %[mask];" \
: [a00s] "+r"(a00s), [a01s] "+r"(a01s), [a10s] "+r"(a10s), \
[a11s] "+r"(a11s) \
: [a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), \
[a11] "r"(a11), [mask] "r"(0x80000000) \
:); \
asm volatile( \
"vfdotpex.s.h %[sum00_real], %[a00s], %[b00];" \
"vfdotpex.s.h %[sum10_real], %[a10s], %[b00];" \
"vfdotpex.s.h %[sum01_real], %[a00s], %[b01];" \
"vfdotpex.s.h %[sum11_real], %[a10s], %[b01];" \
"vfdotpex.s.h %[sum02_real], %[a00s], %[b02];" \
"vfdotpex.s.h %[sum12_real], %[a10s], %[b02];" \
"vfdotpex.s.h %[sum03_real], %[a00s], %[b03];" \
"vfdotpex.s.h %[sum13_real], %[a10s], %[b03];" \
"vfdotpex.s.h %[sum00_real], %[a01s], %[b10];" \
"vfdotpex.s.h %[sum10_real], %[a11s], %[b10];" \
"vfdotpex.s.h %[sum01_real], %[a01s], %[b11];" \
"vfdotpex.s.h %[sum11_real], %[a11s], %[b11];" \
"vfdotpex.s.h %[sum02_real], %[a01s], %[b12];" \
"vfdotpex.s.h %[sum12_real], %[a11s], %[b12];" \
"vfdotpex.s.h %[sum03_real], %[a01s], %[b13];" \
"vfdotpex.s.h %[sum13_real], %[a11s], %[b13];" \
: [sum00_real] "+&r"(sum00_real), [sum01_real] "+&r"(sum01_real), \
[sum02_real] "+&r"(sum02_real), [sum03_real] "+&r"(sum03_real), \
[sum10_real] "+&r"(sum10_real), [sum11_real] "+&r"(sum11_real), \
[sum12_real] "+&r"(sum12_real), [sum13_real] "+&r"(sum13_real) \
: [a00s] "r"(a00s), [a01s] "r"(a01s), [a10s] "r"(a10s), \
[a11s] "r"(a11s), [b00] "r"(b00), [b01] "r"(b01), [b02] "r"(b02), \
[b03] "r"(b03), [b10] "r"(b10), [b11] "r"(b11), [b12] "r"(b12), \
[b13] "r"(b13) \
:); \
} \
v2h res00, res01, res02, res03; \
v2h res10, res11, res12, res13; \
asm volatile( \
"vfcpka.h.s %[res00], %[sum00_real], %[sum00_imag];" \
"vfcpka.h.s %[res01], %[sum01_real], %[sum01_imag];" \
"vfcpka.h.s %[res02], %[sum02_real], %[sum02_imag];" \
"vfcpka.h.s %[res03], %[sum03_real], %[sum03_imag];" \
"vfcpka.h.s %[res10], %[sum10_real], %[sum10_imag];" \
"vfcpka.h.s %[res11], %[sum11_real], %[sum11_imag];" \
"vfcpka.h.s %[res12], %[sum12_real], %[sum12_imag];" \
"vfcpka.h.s %[res13], %[sum13_real], %[sum13_imag];" \
: [res00] "+r"(res00), [res01] "+r"(res01), [res02] "+r"(res02), \
[res03] "+r"(res03), [res10] "+r"(res10), [res11] "+r"(res11), \
[res12] "+r"(res12), [res13] "+r"(res13) \
: [sum00_imag] "r"(sum00_imag), [sum01_imag] "r"(sum01_imag), \
[sum02_imag] "r"(sum02_imag), [sum03_imag] "r"(sum03_imag), \
[sum10_imag] "r"(sum10_imag), [sum11_imag] "r"(sum11_imag), \
[sum12_imag] "r"(sum12_imag), [sum13_imag] "r"(sum13_imag), \
[sum00_real] "r"(sum00_real), [sum01_real] "r"(sum01_real), \
[sum02_real] "r"(sum02_real), [sum03_real] "r"(sum03_real), \
[sum10_real] "r"(sum10_real), [sum11_real] "r"(sum11_real), \
[sum12_real] "r"(sum12_real), [sum13_real] "r"(sum13_real) \
:); \
(*(v2h *)&C[2 * ((i + 0) * P + k + 0)]) = res00; \
(*(v2h *)&C[2 * ((i + 0) * P + k + 1)]) = res01; \
(*(v2h *)&C[2 * ((i + 0) * P + k + 2)]) = res02; \
(*(v2h *)&C[2 * ((i + 0) * P + k + 3)]) = res03; \
(*(v2h *)&C[2 * ((i + 1) * P + k + 0)]) = res10; \
(*(v2h *)&C[2 * ((i + 1) * P + k + 1)]) = res11; \
(*(v2h *)&C[2 * ((i + 1) * P + k + 2)]) = res12; \
(*(v2h *)&C[2 * ((i + 1) * P + k + 3)]) = res13;

void cmatmul_f16s(__fp16 const *__restrict__ A, __fp16 const *__restrict__ B,
__fp16 *__restrict__ C, uint32_t M, uint32_t N, uint32_t P) {
uint32_t i = 0; // loop counter for M
uint32_t j = 0; // loop counter for N
uint32_t k = 0; // loop counter for P
for (k = 0; k < P; k++) {
for (i = 0; i < M; i++) {
float sum_real = 0.0f;
float sum_imag = 0.0f;
for (j = 0; j < N; j++) {
v2h a = (*(v2h *)&A[2 * (i * N + j)]);
v2h b = (*(v2h *)&B[2 * (j * P + k)]);
v2h as;
asm volatile("pv.shuffle2.h %[as], %[a], %[mask_shuffle];"
"vfdotpex.s.h %[sum_imag], %[as], %[b];"
"xor %[as], %[a], %[mask_sign];"
"vfdotpex.s.h %[sum_real], %[as], %[b];"
: [sum_real] "+&r"(sum_real), [sum_imag] "+&r"(sum_imag),
[as] "+&r"(as)
: [a] "r"(a), [b] "r"(b), [mask_shuffle] "r"(0x00020003),
[mask_sign] "r"(0x80000000)
:);
}
v2h res;
asm volatile("vfcpka.h.s %[res], %[sum_real], %[sum_imag];"
: [res] "+r"(res)
: [sum_real] "r"(sum_real), [sum_imag] "r"(sum_imag)
:);
(*(v2h *)&C[2 * (i * P + k)]) = res;
}
}
return;
}

void cmatmul_2x4_f16s(__fp16 const *__restrict__ A,
__fp16 const *__restrict__ B, __fp16 *__restrict__ C,
uint32_t M, uint32_t N, uint32_t P) {
uint32_t i = 0; // loop counter for M
uint32_t j = 0; // loop counter for N
uint32_t k = 0; // loop counter for P
for (k = 0; k < P; k += 4) {
for (i = 0; i < M; i += 2) {
CMATMUL_2x4_LOOP;
}
}
return;
}

void cmatmul_2x4_f16p(__fp16 const *__restrict__ A,
__fp16 const *__restrict__ B, __fp16 *__restrict__ C,
uint32_t M, uint32_t N, uint32_t P, uint32_t core_id,
uint32_t numThreads) {

uint32_t i = 0; // loop counter for M
uint32_t j = 0; // loop counter for N
uint32_t k = 0; // loop counter for P

uint32_t shift_id = core_id % (M / 2);
for (k = core_id * 4; k < P; k += 4 * numThreads) {
for (i = 0; i < shift_id; i += 2) {
CMATMUL_2x4_LOOP;
}
for (i = shift_id; i < M; i += 2) {
CMATMUL_2x4_LOOP;
}
}
return;
}
12 changes: 6 additions & 6 deletions software/runtime/kernel/mempool_checks.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void mempool_check_q32(int32_t *__restrict__ pRes, int32_t *__restrict__ pExp,
int32_t exp = pExp[i];
int32_t res = pRes[i];
error = exp - res;
bool print = ((error > TOL) || (error < (-TOL))) | verbose;
bool print = ((error > TOL) || (error < (-TOL))) || verbose;
if (print) {
printf("CHECK(%d): EXP = %x - RESP = %x\n", i, exp, res);
ERRORS++;
Expand All @@ -48,9 +48,9 @@ void mempool_check_q16(int16_t *__restrict__ pRes, int16_t *__restrict__ pExp,
if (core_id == 0) {
uint32_t ERRORS = 0;
for (uint32_t i = 0; i < NEL; i++) {
int16_t exp = (int16_t) pExp[i];
int16_t res = (int16_t) pRes[i];
error = (int16_t) (exp - res);
int16_t exp = (int16_t)pExp[i];
int16_t res = (int16_t)pRes[i];
error = (int16_t)(exp - res);
bool print = ((error > TOL) || (error < (-TOL))) | verbose;
if (print) {
printf("CHECK(%d): EXP = %x - RESP = %x\n", i, exp, res);
Expand Down Expand Up @@ -84,7 +84,7 @@ void mempool_check_f32(float *__restrict__ pRes, float *__restrict__ pExp,
: [error] "+&r"(error)
: [res] "r"(res), [exp] "r"(exp)
:);
bool print = ((error > TOL) || (error < (-TOL))) | verbose;
bool print = ((error > TOL) || (error < (-TOL))) || verbose;
if (print) {
printf("CHECK(%d): EXP = %x - RESP = %x\n", i, exp, res);
ERRORS++;
Expand Down Expand Up @@ -117,7 +117,7 @@ void mempool_check_f16(__fp16 *__restrict__ pRes, __fp16 *__restrict__ pExp,
: [error] "+&r"(error)
: [res] "r"(res), [exp] "r"(exp)
:);
bool print = ((error > TOL) || (error < (-TOL))) | verbose;
bool print = ((error > TOL) || (error < (-TOL))) || verbose;
if (print) {
printf("CHECK(%d): EXP = %x - RESP = %x\n", i, *(int32_t *)&exp,
*(int32_t *)&res);
Expand Down
Loading

0 comments on commit 1253aad

Please sign in to comment.