Skip to content

Commit

Permalink
Refactored her[2]k/syr[2]k in terms of gemmt. (#531)
Browse files Browse the repository at this point in the history
Details:
- Renamed herk macrokernels and supporting files and functions to gemmt,
  which is possible since at the macrokernel level they are identical.
  Then recast herk/her2k/syrk/syr2k in terms of gemmt within the expert
  level-3 oapi (bli_l3_oapi_ex.c) while also redefining them as literal
  functions rather than cpp macros that instantiate multiple functions.
  Thanks to Devin Matthews for his efforts on this issue (#531).
- Check that the maximum stack buffer size is sufficiently large
  relative to the register blocksizes for each datatype, and do so when
  the context is initialized rather than when an operation is called.
  Note that with this change, users who pass their own contexts into
  the expert interfaces will currently *not* have any checks performed.
  Thanks to Devin Matthews for suggesting this change.
- (cherry picked from commit 28b0982)
  • Loading branch information
fgvanzee committed Sep 10, 2022
1 parent a538807 commit 1652baf
Show file tree
Hide file tree
Showing 52 changed files with 647 additions and 4,396 deletions.
4 changes: 2 additions & 2 deletions config/zen/bli_family_zen.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@

#define BLIS_SMALL_MATRIX_THRES_TRSM 32768 //128(128+128) => m*(m+n)
#define BLIS_SMALL_MATRIX_A_THRES_TRSM 128
#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96
#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128
#define BLIS_SMALL_MATRIX_A_THRES_M_GEMMT 96
#define BLIS_SMALL_MATRIX_A_THRES_N_GEMMT 128

//This macro will enable BLIS DGEMM to choose block sizes for a single instance mode
#define BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES 0
Expand Down
4 changes: 2 additions & 2 deletions config/zen2/bli_family_zen2.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@

#define BLIS_SMALL_MATRIX_THRES_TRSM 32768 //128(128+128) => m*(m+n)
#define BLIS_SMALL_MATRIX_A_THRES_TRSM 128
#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96
#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128
#define BLIS_SMALL_MATRIX_A_THRES_M_GEMMT 96
#define BLIS_SMALL_MATRIX_A_THRES_N_GEMMT 128

#define BLIS_ENABLE_SMALL_MATRIX_ROME
#define BLIS_SMALL_MATRIX_THRES_ROME 400
Expand Down
4 changes: 0 additions & 4 deletions frame/3/bli_l3.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,7 @@
// Operation-specific headers.
#include "bli_gemm.h"
#include "bli_hemm.h"
#include "bli_herk.h"
#include "bli_her2k.h"
#include "bli_symm.h"
#include "bli_syrk.h"
#include "bli_syr2k.h"
#include "bli_trmm.h"
#include "bli_trmm3.h"
#include "bli_trsm.h"
Expand Down
12 changes: 6 additions & 6 deletions frame/3/bli_l3_blocksize.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ dim_t bli_l3_determine_kc

if ( family == BLIS_GEMM )
return bli_gemm_determine_kc( direct, i, dim, a, b, bszid, cntx );
else if ( family == BLIS_HERK )
return bli_herk_determine_kc( direct, i, dim, a, b, bszid, cntx );
else if ( family == BLIS_GEMMT )
return bli_gemmt_determine_kc( direct, i, dim, a, b, bszid, cntx );
else if ( family == BLIS_TRMM )
return bli_trmm_determine_kc( direct, i, dim, a, b, bszid, cntx );
else if ( family == BLIS_TRSM )
Expand Down Expand Up @@ -91,7 +91,7 @@ dim_t PASTEMAC0(opname) \
}

GENFRONT( gemm_determine_kc, gemm )
GENFRONT( herk_determine_kc, herk )
GENFRONT( gemmt_determine_kc, gemmt )
GENFRONT( trmm_determine_kc, trmm )
GENFRONT( trsm_determine_kc, trsm )

Expand Down Expand Up @@ -201,7 +201,7 @@ dim_t PASTEMAC0(opname) \
b_alg = bli_blksz_get_def( dt, bsize ); \
b_max = bli_blksz_get_max( dt, bsize ); \
\
/* Notice that for herk, we do not need to perform any special handling
/* Notice that for gemmt, we do not need to perform any special handling
for the default and maximum kc blocksizes vis-a-vis MR or NR. */ \
\
/* Call the bli_determine_blocksize_[fb]_sub() helper routine defined
Expand All @@ -211,8 +211,8 @@ dim_t PASTEMAC0(opname) \
return b_use; \
}

GENFRONT( herk_determine_kc_f, f )
GENFRONT( herk_determine_kc_b, b )
GENFRONT( gemmt_determine_kc_f, f )
GENFRONT( gemmt_determine_kc_b, b )

// -----------------------------------------------------------------------------

Expand Down
6 changes: 3 additions & 3 deletions frame/3/bli_l3_blocksize.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dim_t PASTEMAC0(opname) \
);

GENPROT( gemm_determine_kc )
GENPROT( herk_determine_kc )
GENPROT( gemmt_determine_kc )
GENPROT( trmm_determine_kc )
GENPROT( trsm_determine_kc )

Expand All @@ -81,8 +81,8 @@ dim_t PASTEMAC0(opname) \
GENPROT( gemm_determine_kc_f )
GENPROT( gemm_determine_kc_b )

GENPROT( herk_determine_kc_f )
GENPROT( herk_determine_kc_b )
GENPROT( gemmt_determine_kc_f )
GENPROT( gemmt_determine_kc_b )

GENPROT( trmm_determine_kc_f )
GENPROT( trmm_determine_kc_b )
Expand Down
5 changes: 0 additions & 5 deletions frame/3/bli_l3_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,5 @@ void bli_l3_basic_check

e_val = bli_check_object_buffer( c );
bli_check_error_code( e_val );

// Check for sufficiently sized stack buffers

e_val = bli_check_sufficient_stack_buf_size( bli_obj_dt( a ), cntx );
bli_check_error_code( e_val );
}

4 changes: 2 additions & 2 deletions frame/3/bli_l3_cntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void bli_l3_cntl_create_if
if ( cntl_orig == NULL )
{
if ( family == BLIS_GEMM ||
family == BLIS_HERK ||
family == BLIS_GEMMT ||
family == BLIS_TRMM )
{
*cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b );
Expand Down Expand Up @@ -97,7 +97,7 @@ void bli_l3_cntl_free
opid_t family = bli_cntl_family( cntl_use );

if ( family == BLIS_GEMM ||
family == BLIS_HERK ||
family == BLIS_GEMMT ||
family == BLIS_TRMM )
{
bli_gemm_cntl_free( rntm, cntl_use, thread );
Expand Down
6 changes: 3 additions & 3 deletions frame/3/bli_l3_direct.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ dir_t bli_l3_direct
opid_t family = bli_cntl_family( cntl );

if ( family == BLIS_GEMM ) return bli_gemm_direct( a, b, c );
else if ( family == BLIS_HERK ) return bli_herk_direct( a, b, c );
else if ( family == BLIS_GEMMT ) return bli_gemmt_direct( a, b, c );
else if ( family == BLIS_TRMM ) return bli_trmm_direct( a, b, c );
else if ( family == BLIS_TRSM ) return bli_trsm_direct( a, b, c );

Expand All @@ -68,14 +68,14 @@ dir_t bli_gemm_direct
return BLIS_FWD;
}

dir_t bli_herk_direct
dir_t bli_gemmt_direct
(
obj_t* a,
obj_t* b,
obj_t* c
)
{
// For herk, movement may be forwards (or backwards).
// For gemmt, movement may be forwards (or backwards).

return BLIS_FWD;
}
Expand Down
2 changes: 1 addition & 1 deletion frame/3/bli_l3_direct.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dir_t PASTEMAC0(opname) \
);

GENPROT( gemm_direct )
GENPROT( herk_direct )
GENPROT( gemmt_direct )
GENPROT( trmm_direct )
GENPROT( trsm_direct )

7 changes: 2 additions & 5 deletions frame/3/bli_l3_ind.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ static bool bli_l3_ind_oper_impl[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] =
static BLIS_THREAD_LOCAL
bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] =
{
/* gemm gemmt hemm herk her2k symm syrk syr2k trmm3 trmm trsm */
/* gemm gemmt hemm herk her2k symm
syrk syr2k trmm3 trmm trsm */
/* c z */
/* 1m */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
Expand All @@ -80,11 +81,7 @@ ind_t PASTEMAC(opname,ind_find_avail)( num_t dt ) \
GENFUNC( gemm, BLIS_GEMM )
GENFUNC( gemmt, BLIS_GEMMT )
GENFUNC( hemm, BLIS_HEMM )
GENFUNC( herk, BLIS_HERK )
GENFUNC( her2k, BLIS_HER2K )
GENFUNC( symm, BLIS_SYMM )
GENFUNC( syrk, BLIS_SYRK )
GENFUNC( syr2k, BLIS_SYR2K )
GENFUNC( trmm3, BLIS_TRMM3 )
GENFUNC( trmm, BLIS_TRMM )
GENFUNC( trsm, BLIS_TRSM )
Expand Down
4 changes: 0 additions & 4 deletions frame/3/bli_l3_ind.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ ind_t PASTEMAC(opname,ind_find_avail)( num_t dt );
GENPROT( gemm )
GENPROT( gemmt )
GENPROT( hemm )
GENPROT( herk )
GENPROT( her2k )
GENPROT( symm )
GENPROT( syrk )
GENPROT( syr2k )
GENPROT( trmm3 )
GENPROT( trmm )
GENPROT( trsm )
Expand Down
Loading

0 comments on commit 1652baf

Please sign in to comment.